You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

206 lines
8.4KB

  1. import math
  2. import os
  3. import time
  4. from copy import deepcopy
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. import torchvision.models as models
  10. def init_seeds(seed=0):
  11. torch.manual_seed(seed)
  12. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  13. if seed == 0: # slower, more reproducible
  14. cudnn.deterministic = True
  15. cudnn.benchmark = False
  16. else: # faster, less reproducible
  17. cudnn.deterministic = False
  18. cudnn.benchmark = True
  19. def select_device(device='', apex=False, batch_size=None):
  20. # device = 'cpu' or '0' or '0,1,2,3'
  21. cpu_request = device.lower() == 'cpu'
  22. if device and not cpu_request: # if device requested other than 'cpu'
  23. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  24. assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device # check availablity
  25. cuda = False if cpu_request else torch.cuda.is_available()
  26. if cuda:
  27. c = 1024 ** 2 # bytes to MB
  28. ng = torch.cuda.device_count()
  29. if ng > 1 and batch_size: # check that batch_size is compatible with device_count
  30. assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
  31. x = [torch.cuda.get_device_properties(i) for i in range(ng)]
  32. s = 'Using CUDA ' + ('Apex ' if apex else '') # apex for mixed precision https://github.com/NVIDIA/apex
  33. for i in range(0, ng):
  34. if i == 1:
  35. s = ' ' * len(s)
  36. print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
  37. (s, i, x[i].name, x[i].total_memory / c))
  38. else:
  39. print('Using CPU')
  40. print('') # skip a line
  41. return torch.device('cuda:0' if cuda else 'cpu')
  42. def time_synchronized():
  43. torch.cuda.synchronize() if torch.cuda.is_available() else None
  44. return time.time()
  45. def is_parallel(model):
  46. # is model is parallel with DP or DDP
  47. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  48. def initialize_weights(model):
  49. for m in model.modules():
  50. t = type(m)
  51. if t is nn.Conv2d:
  52. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  53. elif t is nn.BatchNorm2d:
  54. m.eps = 1e-4
  55. m.momentum = 0.03
  56. elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  57. m.inplace = True
  58. def find_modules(model, mclass=nn.Conv2d):
  59. # finds layer indices matching module class 'mclass'
  60. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  61. def fuse_conv_and_bn(conv, bn):
  62. # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  63. with torch.no_grad():
  64. # init
  65. fusedconv = torch.nn.Conv2d(conv.in_channels,
  66. conv.out_channels,
  67. kernel_size=conv.kernel_size,
  68. stride=conv.stride,
  69. padding=conv.padding,
  70. bias=True)
  71. # prepare filters
  72. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  73. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  74. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
  75. # prepare spatial bias
  76. if conv.bias is not None:
  77. b_conv = conv.bias
  78. else:
  79. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
  80. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  81. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  82. return fusedconv
  83. def model_info(model, verbose=False):
  84. # Plots a line-by-line description of a PyTorch model
  85. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  86. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  87. if verbose:
  88. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  89. for i, (name, p) in enumerate(model.named_parameters()):
  90. name = name.replace('module_list.', '')
  91. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  92. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  93. try: # FLOPS
  94. from thop import profile
  95. flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
  96. fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS
  97. except:
  98. fs = ''
  99. print('Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
  100. def load_classifier(name='resnet101', n=2):
  101. # Loads a pretrained model reshaped to n-class output
  102. model = models.__dict__[name](pretrained=True)
  103. # Display model properties
  104. input_size = [3, 224, 224]
  105. input_space = 'RGB'
  106. input_range = [0, 1]
  107. mean = [0.485, 0.456, 0.406]
  108. std = [0.229, 0.224, 0.225]
  109. for x in [input_size, input_space, input_range, mean, std]:
  110. print(x + ' =', eval(x))
  111. # Reshape output to n classes
  112. filters = model.fc.weight.shape[1]
  113. model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True)
  114. model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  115. model.fc.out_features = n
  116. return model
  117. def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
  118. # scales img(bs,3,y,x) by ratio
  119. h, w = img.shape[2:]
  120. s = (int(h * ratio), int(w * ratio)) # new size
  121. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  122. if not same_shape: # pad/crop img
  123. gs = 32 # (pixels) grid size
  124. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  125. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  126. class ModelEMA:
  127. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  128. Keep a moving average of everything in the model state_dict (parameters and buffers).
  129. This is intended to allow functionality like
  130. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  131. A smoothed version of the weights is necessary for some training schemes to perform well.
  132. E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
  133. RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
  134. smoothing of weights to match results. Pay attention to the decay constant you are using
  135. relative to your update count per epoch.
  136. To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
  137. disable validation of the EMA weights. Validation will have to be done manually in a separate
  138. process, or after the training stops converging.
  139. This class is sensitive where it is initialized in the sequence of model init,
  140. GPU assignment and distributed training wrappers.
  141. I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
  142. """
  143. def __init__(self, model, decay=0.9999, device=''):
  144. # Create EMA
  145. self.ema = deepcopy(model.module if is_parallel(model) else model) # FP32 EMA
  146. self.ema.eval()
  147. self.updates = 0 # number of EMA updates
  148. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  149. self.device = device # perform ema on different device from model if set
  150. if device:
  151. self.ema.to(device)
  152. for p in self.ema.parameters():
  153. p.requires_grad_(False)
  154. def update(self, model):
  155. # Update EMA parameters
  156. with torch.no_grad():
  157. self.updates += 1
  158. d = self.decay(self.updates)
  159. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  160. for k, v in self.ema.state_dict().items():
  161. if v.dtype.is_floating_point:
  162. v *= d
  163. v += (1. - d) * msd[k].detach()
  164. def update_attr(self, model):
  165. # Update EMA attributes
  166. for k, v in model.__dict__.items():
  167. if not k.startswith('_') and k not in ["process_group", "reducer"]:
  168. setattr(self.ema, k, v)