You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

223 lines
8.4KB

  1. import math
  2. import os
  3. import time
  4. from copy import deepcopy
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. import torchvision.models as models
  10. def init_seeds(seed=0):
  11. torch.manual_seed(seed)
  12. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  13. if seed == 0: # slower, more reproducible
  14. cudnn.deterministic = True
  15. cudnn.benchmark = False
  16. else: # faster, less reproducible
  17. cudnn.deterministic = False
  18. cudnn.benchmark = True
  19. def select_device(device='', batch_size=None):
  20. # device = 'cpu' or '0' or '0,1,2,3'
  21. cpu_request = device.lower() == 'cpu'
  22. if device and not cpu_request: # if device requested other than 'cpu'
  23. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  24. assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device # check availablity
  25. cuda = False if cpu_request else torch.cuda.is_available()
  26. if cuda:
  27. c = 1024 ** 2 # bytes to MB
  28. ng = torch.cuda.device_count()
  29. if ng > 1 and batch_size: # check that batch_size is compatible with device_count
  30. assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
  31. x = [torch.cuda.get_device_properties(i) for i in range(ng)]
  32. s = 'Using CUDA '
  33. for i in range(0, ng):
  34. if i == 1:
  35. s = ' ' * len(s)
  36. print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
  37. (s, i, x[i].name, x[i].total_memory / c))
  38. else:
  39. print('Using CPU')
  40. print('') # skip a line
  41. return torch.device('cuda:0' if cuda else 'cpu')
  42. def time_synchronized():
  43. torch.cuda.synchronize() if torch.cuda.is_available() else None
  44. return time.time()
  45. def is_parallel(model):
  46. # is model is parallel with DP or DDP
  47. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  48. def initialize_weights(model):
  49. for m in model.modules():
  50. t = type(m)
  51. if t is nn.Conv2d:
  52. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  53. elif t is nn.BatchNorm2d:
  54. m.eps = 1e-3
  55. m.momentum = 0.03
  56. elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  57. m.inplace = True
  58. def find_modules(model, mclass=nn.Conv2d):
  59. # finds layer indices matching module class 'mclass'
  60. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  61. def sparsity(model):
  62. # Return global model sparsity
  63. a, b = 0., 0.
  64. for p in model.parameters():
  65. a += p.numel()
  66. b += (p == 0).sum()
  67. return b / a
  68. def prune(model, amount=0.3):
  69. # Prune model to requested global sparsity
  70. import torch.nn.utils.prune as prune
  71. print('Pruning model... ', end='')
  72. for name, m in model.named_modules():
  73. if isinstance(m, nn.Conv2d):
  74. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  75. prune.remove(m, 'weight') # make permanent
  76. print(' %.3g global sparsity' % sparsity(model))
  77. def fuse_conv_and_bn(conv, bn):
  78. # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  79. with torch.no_grad():
  80. # init
  81. fusedconv = nn.Conv2d(conv.in_channels,
  82. conv.out_channels,
  83. kernel_size=conv.kernel_size,
  84. stride=conv.stride,
  85. padding=conv.padding,
  86. bias=True).to(conv.weight.device)
  87. # prepare filters
  88. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  89. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  90. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
  91. # prepare spatial bias
  92. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  93. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  94. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  95. return fusedconv
  96. def model_info(model, verbose=False):
  97. # Plots a line-by-line description of a PyTorch model
  98. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  99. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  100. if verbose:
  101. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  102. for i, (name, p) in enumerate(model.named_parameters()):
  103. name = name.replace('module_list.', '')
  104. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  105. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  106. try: # FLOPS
  107. from thop import profile
  108. flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
  109. fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS
  110. except:
  111. fs = ''
  112. print('Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
  113. def load_classifier(name='resnet101', n=2):
  114. # Loads a pretrained model reshaped to n-class output
  115. model = models.__dict__[name](pretrained=True)
  116. # Display model properties
  117. input_size = [3, 224, 224]
  118. input_space = 'RGB'
  119. input_range = [0, 1]
  120. mean = [0.485, 0.456, 0.406]
  121. std = [0.229, 0.224, 0.225]
  122. for x in [input_size, input_space, input_range, mean, std]:
  123. print(x + ' =', eval(x))
  124. # Reshape output to n classes
  125. filters = model.fc.weight.shape[1]
  126. model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
  127. model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  128. model.fc.out_features = n
  129. return model
  130. def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
  131. # scales img(bs,3,y,x) by ratio
  132. if ratio == 1.0:
  133. return img
  134. else:
  135. h, w = img.shape[2:]
  136. s = (int(h * ratio), int(w * ratio)) # new size
  137. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  138. if not same_shape: # pad/crop img
  139. gs = 32 # (pixels) grid size
  140. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  141. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  142. def copy_attr(a, b, include=(), exclude=()):
  143. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  144. for k, v in b.__dict__.items():
  145. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  146. continue
  147. else:
  148. setattr(a, k, v)
  149. class ModelEMA:
  150. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  151. Keep a moving average of everything in the model state_dict (parameters and buffers).
  152. This is intended to allow functionality like
  153. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  154. A smoothed version of the weights is necessary for some training schemes to perform well.
  155. This class is sensitive where it is initialized in the sequence of model init,
  156. GPU assignment and distributed training wrappers.
  157. """
  158. def __init__(self, model, decay=0.9999, updates=0):
  159. # Create EMA
  160. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  161. # if next(model.parameters()).device.type != 'cpu':
  162. # self.ema.half() # FP16 EMA
  163. self.updates = updates # number of EMA updates
  164. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  165. for p in self.ema.parameters():
  166. p.requires_grad_(False)
  167. def update(self, model):
  168. # Update EMA parameters
  169. with torch.no_grad():
  170. self.updates += 1
  171. d = self.decay(self.updates)
  172. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  173. for k, v in self.ema.state_dict().items():
  174. if v.dtype.is_floating_point:
  175. v *= d
  176. v += (1. - d) * msd[k].detach()
  177. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  178. # Update EMA attributes
  179. copy_attr(self.ema, model, include, exclude)