You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

231 lines
8.7KB

  1. import math
  2. import os
  3. import time
  4. import logging
  5. from copy import deepcopy
  6. import torch
  7. import torch.backends.cudnn as cudnn
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. import torchvision.models as models
  11. logger = logging.getLogger(__name__)
  12. def init_seeds(seed=0):
  13. torch.manual_seed(seed)
  14. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  15. if seed == 0: # slower, more reproducible
  16. cudnn.deterministic = True
  17. cudnn.benchmark = False
  18. else: # faster, less reproducible
  19. cudnn.deterministic = False
  20. cudnn.benchmark = True
  21. def select_device(device='', batch_size=None):
  22. # device = 'cpu' or '0' or '0,1,2,3'
  23. cpu_request = device.lower() == 'cpu'
  24. if device and not cpu_request: # if device requested other than 'cpu'
  25. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  26. assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device # check availablity
  27. cuda = False if cpu_request else torch.cuda.is_available()
  28. if cuda:
  29. c = 1024 ** 2 # bytes to MB
  30. ng = torch.cuda.device_count()
  31. if ng > 1 and batch_size: # check that batch_size is compatible with device_count
  32. assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
  33. x = [torch.cuda.get_device_properties(i) for i in range(ng)]
  34. s = 'Using CUDA '
  35. for i in range(0, ng):
  36. if i == 1:
  37. s = ' ' * len(s)
  38. logger.info("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
  39. (s, i, x[i].name, x[i].total_memory / c))
  40. else:
  41. logger.info('Using CPU')
  42. logger.info('') # skip a line
  43. return torch.device('cuda:0' if cuda else 'cpu')
  44. def time_synchronized():
  45. torch.cuda.synchronize() if torch.cuda.is_available() else None
  46. return time.time()
  47. def is_parallel(model):
  48. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  49. def intersect_dicts(da, db, exclude=()):
  50. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  51. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  52. def initialize_weights(model):
  53. for m in model.modules():
  54. t = type(m)
  55. if t is nn.Conv2d:
  56. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  57. elif t is nn.BatchNorm2d:
  58. m.eps = 1e-3
  59. m.momentum = 0.03
  60. elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  61. m.inplace = True
  62. def find_modules(model, mclass=nn.Conv2d):
  63. # Finds layer indices matching module class 'mclass'
  64. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  65. def sparsity(model):
  66. # Return global model sparsity
  67. a, b = 0., 0.
  68. for p in model.parameters():
  69. a += p.numel()
  70. b += (p == 0).sum()
  71. return b / a
  72. def prune(model, amount=0.3):
  73. # Prune model to requested global sparsity
  74. import torch.nn.utils.prune as prune
  75. print('Pruning model... ', end='')
  76. for name, m in model.named_modules():
  77. if isinstance(m, nn.Conv2d):
  78. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  79. prune.remove(m, 'weight') # make permanent
  80. print(' %.3g global sparsity' % sparsity(model))
  81. def fuse_conv_and_bn(conv, bn):
  82. # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  83. with torch.no_grad():
  84. # init
  85. fusedconv = nn.Conv2d(conv.in_channels,
  86. conv.out_channels,
  87. kernel_size=conv.kernel_size,
  88. stride=conv.stride,
  89. padding=conv.padding,
  90. bias=True).to(conv.weight.device)
  91. # prepare filters
  92. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  93. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  94. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
  95. # prepare spatial bias
  96. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  97. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  98. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  99. return fusedconv
  100. def model_info(model, verbose=False):
  101. # Plots a line-by-line description of a PyTorch model
  102. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  103. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  104. if verbose:
  105. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  106. for i, (name, p) in enumerate(model.named_parameters()):
  107. name = name.replace('module_list.', '')
  108. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  109. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  110. try: # FLOPS
  111. from thop import profile
  112. flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
  113. fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS
  114. except:
  115. fs = ''
  116. logger.info(
  117. 'Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
  118. def load_classifier(name='resnet101', n=2):
  119. # Loads a pretrained model reshaped to n-class output
  120. model = models.__dict__[name](pretrained=True)
  121. # Display model properties
  122. input_size = [3, 224, 224]
  123. input_space = 'RGB'
  124. input_range = [0, 1]
  125. mean = [0.485, 0.456, 0.406]
  126. std = [0.229, 0.224, 0.225]
  127. for x in ['input_size', 'input_space', 'input_range', 'mean', 'std']:
  128. print(x + ' =', eval(x))
  129. # Reshape output to n classes
  130. filters = model.fc.weight.shape[1]
  131. model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
  132. model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  133. model.fc.out_features = n
  134. return model
  135. def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
  136. # scales img(bs,3,y,x) by ratio
  137. if ratio == 1.0:
  138. return img
  139. else:
  140. h, w = img.shape[2:]
  141. s = (int(h * ratio), int(w * ratio)) # new size
  142. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  143. if not same_shape: # pad/crop img
  144. gs = 32 # (pixels) grid size
  145. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  146. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  147. def copy_attr(a, b, include=(), exclude=()):
  148. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  149. for k, v in b.__dict__.items():
  150. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  151. continue
  152. else:
  153. setattr(a, k, v)
  154. class ModelEMA:
  155. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  156. Keep a moving average of everything in the model state_dict (parameters and buffers).
  157. This is intended to allow functionality like
  158. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  159. A smoothed version of the weights is necessary for some training schemes to perform well.
  160. This class is sensitive where it is initialized in the sequence of model init,
  161. GPU assignment and distributed training wrappers.
  162. """
  163. def __init__(self, model, decay=0.9999, updates=0):
  164. # Create EMA
  165. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  166. # if next(model.parameters()).device.type != 'cpu':
  167. # self.ema.half() # FP16 EMA
  168. self.updates = updates # number of EMA updates
  169. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  170. for p in self.ema.parameters():
  171. p.requires_grad_(False)
  172. def update(self, model):
  173. # Update EMA parameters
  174. with torch.no_grad():
  175. self.updates += 1
  176. d = self.decay(self.updates)
  177. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  178. for k, v in self.ema.state_dict().items():
  179. if v.dtype.is_floating_point:
  180. v *= d
  181. v += (1. - d) * msd[k].detach()
  182. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  183. # Update EMA attributes
  184. copy_attr(self.ema, model, include, exclude)