You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

230 lines
8.7KB

  1. import logging
  2. import math
  3. import os
  4. import time
  5. from copy import deepcopy
  6. import torch
  7. import torch.backends.cudnn as cudnn
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. import torchvision
  11. logger = logging.getLogger(__name__)
  12. def init_torch_seeds(seed=0):
  13. torch.manual_seed(seed)
  14. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  15. if seed == 0: # slower, more reproducible
  16. cudnn.deterministic = True
  17. cudnn.benchmark = False
  18. else: # faster, less reproducible
  19. cudnn.deterministic = False
  20. cudnn.benchmark = True
  21. def select_device(device='', batch_size=None):
  22. # device = 'cpu' or '0' or '0,1,2,3'
  23. cpu_request = device.lower() == 'cpu'
  24. if device and not cpu_request: # if device requested other than 'cpu'
  25. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  26. assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device # check availablity
  27. cuda = False if cpu_request else torch.cuda.is_available()
  28. if cuda:
  29. c = 1024 ** 2 # bytes to MB
  30. ng = torch.cuda.device_count()
  31. if ng > 1 and batch_size: # check that batch_size is compatible with device_count
  32. assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
  33. x = [torch.cuda.get_device_properties(i) for i in range(ng)]
  34. s = 'Using CUDA '
  35. for i in range(0, ng):
  36. if i == 1:
  37. s = ' ' * len(s)
  38. logger.info("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
  39. (s, i, x[i].name, x[i].total_memory / c))
  40. else:
  41. logger.info('Using CPU')
  42. logger.info('') # skip a line
  43. return torch.device('cuda:0' if cuda else 'cpu')
  44. def time_synchronized():
  45. torch.cuda.synchronize() if torch.cuda.is_available() else None
  46. return time.time()
  47. def is_parallel(model):
  48. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  49. def intersect_dicts(da, db, exclude=()):
  50. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  51. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  52. def initialize_weights(model):
  53. for m in model.modules():
  54. t = type(m)
  55. if t is nn.Conv2d:
  56. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  57. elif t is nn.BatchNorm2d:
  58. m.eps = 1e-3
  59. m.momentum = 0.03
  60. elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  61. m.inplace = True
  62. def find_modules(model, mclass=nn.Conv2d):
  63. # Finds layer indices matching module class 'mclass'
  64. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  65. def sparsity(model):
  66. # Return global model sparsity
  67. a, b = 0., 0.
  68. for p in model.parameters():
  69. a += p.numel()
  70. b += (p == 0).sum()
  71. return b / a
  72. def prune(model, amount=0.3):
  73. # Prune model to requested global sparsity
  74. import torch.nn.utils.prune as prune
  75. print('Pruning model... ', end='')
  76. for name, m in model.named_modules():
  77. if isinstance(m, nn.Conv2d):
  78. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  79. prune.remove(m, 'weight') # make permanent
  80. print(' %.3g global sparsity' % sparsity(model))
  81. def fuse_conv_and_bn(conv, bn):
  82. # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  83. # init
  84. fusedconv = nn.Conv2d(conv.in_channels,
  85. conv.out_channels,
  86. kernel_size=conv.kernel_size,
  87. stride=conv.stride,
  88. padding=conv.padding,
  89. groups=conv.groups,
  90. bias=True).requires_grad_(False).to(conv.weight.device)
  91. # prepare filters
  92. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  93. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  94. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
  95. # prepare spatial bias
  96. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  97. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  98. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  99. return fusedconv
  100. def model_info(model, verbose=False):
  101. # Plots a line-by-line description of a PyTorch model
  102. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  103. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  104. if verbose:
  105. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  106. for i, (name, p) in enumerate(model.named_parameters()):
  107. name = name.replace('module_list.', '')
  108. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  109. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  110. try: # FLOPS
  111. from thop import profile
  112. flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
  113. fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS
  114. except:
  115. fs = ''
  116. logger.info(
  117. 'Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
  118. def load_classifier(name='resnet101', n=2):
  119. # Loads a pretrained model reshaped to n-class output
  120. model = torchvision.models.__dict__[name](pretrained=True)
  121. # ResNet model properties
  122. # input_size = [3, 224, 224]
  123. # input_space = 'RGB'
  124. # input_range = [0, 1]
  125. # mean = [0.485, 0.456, 0.406]
  126. # std = [0.229, 0.224, 0.225]
  127. # Reshape output to n classes
  128. filters = model.fc.weight.shape[1]
  129. model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
  130. model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  131. model.fc.out_features = n
  132. return model
  133. def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
  134. # scales img(bs,3,y,x) by ratio
  135. if ratio == 1.0:
  136. return img
  137. else:
  138. h, w = img.shape[2:]
  139. s = (int(h * ratio), int(w * ratio)) # new size
  140. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  141. if not same_shape: # pad/crop img
  142. gs = 32 # (pixels) grid size
  143. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  144. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  145. def copy_attr(a, b, include=(), exclude=()):
  146. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  147. for k, v in b.__dict__.items():
  148. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  149. continue
  150. else:
  151. setattr(a, k, v)
  152. class ModelEMA:
  153. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  154. Keep a moving average of everything in the model state_dict (parameters and buffers).
  155. This is intended to allow functionality like
  156. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  157. A smoothed version of the weights is necessary for some training schemes to perform well.
  158. This class is sensitive where it is initialized in the sequence of model init,
  159. GPU assignment and distributed training wrappers.
  160. """
  161. def __init__(self, model, decay=0.9999, updates=0):
  162. # Create EMA
  163. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  164. # if next(model.parameters()).device.type != 'cpu':
  165. # self.ema.half() # FP16 EMA
  166. self.updates = updates # number of EMA updates
  167. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  168. for p in self.ema.parameters():
  169. p.requires_grad_(False)
  170. def update(self, model):
  171. # Update EMA parameters
  172. with torch.no_grad():
  173. self.updates += 1
  174. d = self.decay(self.updates)
  175. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  176. for k, v in self.ema.state_dict().items():
  177. if v.dtype.is_floating_point:
  178. v *= d
  179. v += (1. - d) * msd[k].detach()
  180. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  181. # Update EMA attributes
  182. copy_attr(self.ema, model, include, exclude)