# Adapted from https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/syncbn.py
"""Synchronized Cross-GPU Batch Normalization Module"""
import warnings
import torch
import torch.cuda.comm as comm
from queue import Queue
from torch.autograd import Function
from torch.nn.modules.batchnorm import _BatchNorm
from torch.autograd.function import once_differentiable
from core.nn import _C

__all__ = ['SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']

class _SyncBatchNorm(Function):
    @classmethod
    def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
                extra, sync=True, training=True, momentum=0.1, eps=1e-05,
                activation="none", slope=0.01):
        # save context
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        assert activation == 'none'

        # contiguous inputs
        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()

        if ctx.training:
            # per-device moments E[x] and E[x^2]
            _ex, _exs = _C.expectation_forward(x)

            if ctx.sync:
                if ctx.is_master:
                    # collect the moments of every worker, average them,
                    # then broadcast the global moments back
                    _ex, _exs = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _ex_w, _exs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _ex.append(_ex_w.unsqueeze(0))
                        _exs.append(_exs_w.unsqueeze(0))

                    _ex = comm.gather(_ex).mean(0)
                    _exs = comm.gather(_exs).mean(0)

                    tensors = comm.broadcast_coalesced((_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    # send local moments to the master and wait for the global ones
                    ctx.master_queue.put((_ex, _exs))
                    _ex, _exs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            # Update running stats
            _var = _exs - _ex ** 2
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * _ex)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * _var)

            # Mark in-place modified tensors
            ctx.mark_dirty(running_mean, running_var)
        else:
            # evaluation: use the stored running statistics
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex ** 2

        # BN forward
        y = _C.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        # Output
        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return y

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        x, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # BN backward
        dx, _dex, _dexs, dgamma, dbeta = _C.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)

        if ctx.training:
            if ctx.sync:
                if ctx.is_master:
                    # same gather/average/broadcast exchange as in forward,
                    # applied to the gradients w.r.t. the moments
                    _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _dex_w, _dexs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _dex.append(_dex_w.unsqueeze(0))
                        _dexs.append(_dexs_w.unsqueeze(0))

                    _dex = comm.gather(_dex).mean(0)
                    _dexs = comm.gather(_dexs).mean(0)

                    tensors = comm.broadcast_coalesced((_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    ctx.master_queue.put((_dex, _dexs))
                    _dex, _dexs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            # add the contribution of x to the synchronized moments
            dx_ = _C.expectation_backward(x, _dex, _dexs)
            dx = dx + dx_

        # one return value per forward() input; the non-tensor arguments get None
        return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None

    @staticmethod
    def _parse_extra(ctx, extra):
        ctx.is_master = extra["is_master"]
        if ctx.is_master:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queue = extra["worker_queue"]


syncbatchnorm = _SyncBatchNorm.apply
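
# For reference (an added note restating the standard BN math that the _C
# kernels implement, not anything beyond what the code above already does):
# given the synchronized moments ex = E[x] and exs = E[x^2] gathered above,
# the normalization is
#
#     var = exs - ex^2
#     y   = gamma * (x - ex) / sqrt(var + eps) + beta
#
# Averaging E[x] and E[x^2] across devices is what makes this exact: the
# global variance is recovered from the averaged moments, which would not
# hold if per-device variances were averaged directly.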


class SyncBatchNorm(_BatchNorm):
    """Cross-GPU Synchronized Batch normalization (SyncBN)

    Parameters:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        sync: a boolean value that, when set to ``True``, synchronizes the
            statistics across different GPUs. Default: ``True``
        activation : str
            Name of the activation function, one of: `leaky_relu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*

    Examples:
        >>> m = SyncBatchNorm(100)
        >>> net = torch.nn.DataParallel(m)
        >>> output = net(input)
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation='none', slope=0.01):
        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True)
        self.activation = activation
        self.slope = slope
        self.devices = list(range(torch.cuda.device_count()))
        self.sync = sync if len(self.devices) > 1 else False
        # Initialize queues
        self.worker_ids = self.devices[1:]
        self.master_queue = Queue(len(self.worker_ids))
        self.worker_queues = [Queue(1) for _ in self.worker_ids]

    def forward(self, x):
        # resize the input to (B, C, -1)
        input_shape = x.size()
        x = x.view(input_shape[0], self.num_features, -1)
        if x.get_device() == self.devices[0]:
            # Master mode
            extra = {
                "is_master": True,
                "master_queue": self.master_queue,
                "worker_queues": self.worker_queues,
                "worker_ids": self.worker_ids
            }
        else:
            # Worker mode
            extra = {
                "is_master": False,
                "master_queue": self.master_queue,
                "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
            }
        return syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                             extra, self.sync, self.training, self.momentum, self.eps,
                             self.activation, self.slope).view(input_shape)

    def extra_repr(self):
        if self.activation == 'none':
            return 'sync={}'.format(self.sync)
        else:
            return 'sync={}, act={}, slope={}'.format(
                self.sync, self.activation, self.slope)


class BatchNorm1d(SyncBatchNorm):
    """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""

    def __init__(self, *args, **kwargs):
        warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
                      .format('BatchNorm1d', SyncBatchNorm.__name__), DeprecationWarning)
        super(BatchNorm1d, self).__init__(*args, **kwargs)


class BatchNorm2d(SyncBatchNorm):
    """BatchNorm2d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""

    def __init__(self, *args, **kwargs):
        warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
                      .format('BatchNorm2d', SyncBatchNorm.__name__), DeprecationWarning)
        super(BatchNorm2d, self).__init__(*args, **kwargs)


class BatchNorm3d(SyncBatchNorm):
    """BatchNorm3d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""

    def __init__(self, *args, **kwargs):
        warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
                      .format('BatchNorm3d', SyncBatchNorm.__name__), DeprecationWarning)
        super(BatchNorm3d, self).__init__(*args, **kwargs)
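

# A minimal usage sketch expanding the docstring example above. This is an
# illustration, not part of the upstream file: it assumes the core.nn._C
# extension is built and at least one CUDA device is available; with a single
# device the module silently falls back to sync=False. Tensor sizes are
# illustrative.
if __name__ == '__main__':
    model = torch.nn.DataParallel(SyncBatchNorm(100)).cuda()
    data = torch.randn(16, 100, 32, 32, device='cuda')
    out = model(data)
    print(out.shape)  # torch.Size([16, 100, 32, 32])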