You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

230 lines
8.7KB

  1. import logging
  2. import math
  3. import os
  4. import time
  5. from copy import deepcopy
  6. import torch
  7. import torch.backends.cudnn as cudnn
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. logger = logging.getLogger(__name__)
  11. def init_torch_seeds(seed=0):
  12. torch.manual_seed(seed)
  13. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  14. if seed == 0: # slower, more reproducible
  15. cudnn.deterministic = True
  16. cudnn.benchmark = False
  17. else: # faster, less reproducible
  18. cudnn.deterministic = False
  19. cudnn.benchmark = True
  20. def select_device(device='', batch_size=None):
  21. # device = 'cpu' or '0' or '0,1,2,3'
  22. cpu_request = device.lower() == 'cpu'
  23. if device and not cpu_request: # if device requested other than 'cpu'
  24. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  25. assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device # check availablity
  26. cuda = False if cpu_request else torch.cuda.is_available()
  27. if cuda:
  28. c = 1024 ** 2 # bytes to MB
  29. ng = torch.cuda.device_count()
  30. if ng > 1 and batch_size: # check that batch_size is compatible with device_count
  31. assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
  32. x = [torch.cuda.get_device_properties(i) for i in range(ng)]
  33. s = 'Using CUDA '
  34. for i in range(0, ng):
  35. if i == 1:
  36. s = ' ' * len(s)
  37. logger.info("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
  38. (s, i, x[i].name, x[i].total_memory / c))
  39. else:
  40. logger.info('Using CPU')
  41. logger.info('') # skip a line
  42. return torch.device('cuda:0' if cuda else 'cpu')
  43. def time_synchronized():
  44. torch.cuda.synchronize() if torch.cuda.is_available() else None
  45. return time.time()
  46. def is_parallel(model):
  47. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  48. def intersect_dicts(da, db, exclude=()):
  49. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  50. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  51. def initialize_weights(model):
  52. for m in model.modules():
  53. t = type(m)
  54. if t is nn.Conv2d:
  55. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  56. elif t is nn.BatchNorm2d:
  57. m.eps = 1e-3
  58. m.momentum = 0.03
  59. elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  60. m.inplace = True
  61. def find_modules(model, mclass=nn.Conv2d):
  62. # Finds layer indices matching module class 'mclass'
  63. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  64. def sparsity(model):
  65. # Return global model sparsity
  66. a, b = 0., 0.
  67. for p in model.parameters():
  68. a += p.numel()
  69. b += (p == 0).sum()
  70. return b / a
  71. def prune(model, amount=0.3):
  72. # Prune model to requested global sparsity
  73. import torch.nn.utils.prune as prune
  74. print('Pruning model... ', end='')
  75. for name, m in model.named_modules():
  76. if isinstance(m, nn.Conv2d):
  77. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  78. prune.remove(m, 'weight') # make permanent
  79. print(' %.3g global sparsity' % sparsity(model))
  80. def fuse_conv_and_bn(conv, bn):
  81. # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  82. # init
  83. fusedconv = nn.Conv2d(conv.in_channels,
  84. conv.out_channels,
  85. kernel_size=conv.kernel_size,
  86. stride=conv.stride,
  87. padding=conv.padding,
  88. groups=conv.groups,
  89. bias=True).requires_grad_(False).to(conv.weight.device)
  90. # prepare filters
  91. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  92. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  93. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
  94. # prepare spatial bias
  95. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  96. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  97. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  98. return fusedconv
  99. def model_info(model, verbose=False):
  100. # Plots a line-by-line description of a PyTorch model
  101. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  102. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  103. if verbose:
  104. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  105. for i, (name, p) in enumerate(model.named_parameters()):
  106. name = name.replace('module_list.', '')
  107. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  108. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  109. try: # FLOPS
  110. from thop import profile
  111. flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
  112. fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS
  113. except:
  114. fs = ''
  115. logger.info(
  116. 'Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
  117. def load_classifier(name='resnet101', n=2):
  118. # Loads a pretrained model reshaped to n-class output
  119. import torchvision
  120. model = torchvision.models.__dict__[name](pretrained=True)
  121. # ResNet model properties
  122. # input_size = [3, 224, 224]
  123. # input_space = 'RGB'
  124. # input_range = [0, 1]
  125. # mean = [0.485, 0.456, 0.406]
  126. # std = [0.229, 0.224, 0.225]
  127. # Reshape output to n classes
  128. filters = model.fc.weight.shape[1]
  129. model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
  130. model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  131. model.fc.out_features = n
  132. return model
  133. def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
  134. # scales img(bs,3,y,x) by ratio
  135. if ratio == 1.0:
  136. return img
  137. else:
  138. h, w = img.shape[2:]
  139. s = (int(h * ratio), int(w * ratio)) # new size
  140. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  141. if not same_shape: # pad/crop img
  142. gs = 32 # (pixels) grid size
  143. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  144. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  145. def copy_attr(a, b, include=(), exclude=()):
  146. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  147. for k, v in b.__dict__.items():
  148. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  149. continue
  150. else:
  151. setattr(a, k, v)
  152. class ModelEMA:
  153. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  154. Keep a moving average of everything in the model state_dict (parameters and buffers).
  155. This is intended to allow functionality like
  156. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  157. A smoothed version of the weights is necessary for some training schemes to perform well.
  158. This class is sensitive where it is initialized in the sequence of model init,
  159. GPU assignment and distributed training wrappers.
  160. """
  161. def __init__(self, model, decay=0.9999, updates=0):
  162. # Create EMA
  163. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  164. # if next(model.parameters()).device.type != 'cpu':
  165. # self.ema.half() # FP16 EMA
  166. self.updates = updates # number of EMA updates
  167. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  168. for p in self.ema.parameters():
  169. p.requires_grad_(False)
  170. def update(self, model):
  171. # Update EMA parameters
  172. with torch.no_grad():
  173. self.updates += 1
  174. d = self.decay(self.updates)
  175. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  176. for k, v in self.ema.state_dict().items():
  177. if v.dtype.is_floating_point:
  178. v *= d
  179. v += (1. - d) * msd[k].detach()
  180. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  181. # Update EMA attributes
  182. copy_attr(self.ema, model, include, exclude)