- # Adapted from https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/syncbn.py
- """Synchronized Cross-GPU Batch Normalization Module"""
- import warnings
- import torch
- import torch.cuda.comm as comm
-
- from queue import Queue
- from torch.autograd import Function
- from torch.nn.modules.batchnorm import _BatchNorm
- from torch.autograd.function import once_differentiable
- from core.nn import _C
-
- __all__ = ['SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']
-
-
- class _SyncBatchNorm(Function):
- @classmethod
- def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
- extra, sync=True, training=True, momentum=0.1, eps=1e-05,
- activation="none", slope=0.01):
- # save context
- cls._parse_extra(ctx, extra)
- ctx.sync = sync
- ctx.training = training
- ctx.momentum = momentum
- ctx.eps = eps
- ctx.activation = activation
- ctx.slope = slope
- assert activation == 'none'  # only the plain BN path (no fused activation) is supported here
-
- # make inputs contiguous
- x = x.contiguous()
- gamma = gamma.contiguous()
- beta = beta.contiguous()
-
- if ctx.training:
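- # per-GPU batch statistics: _ex = E[x] and _exs = E[x^2], computed per channel by the extension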
- _ex, _exs = _C.expectation_forward(x)
-
- if ctx.sync:
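- # cross-GPU synchronization: the master averages the per-GPU statistics and broadcasts the result back through the queues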
- if ctx.is_master:
- _ex, _exs = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
- for _ in range(ctx.master_queue.maxsize):
- _ex_w, _exs_w = ctx.master_queue.get()
- ctx.master_queue.task_done()
- _ex.append(_ex_w.unsqueeze(0))
- _exs.append(_exs_w.unsqueeze(0))
-
- _ex = comm.gather(_ex).mean(0)
- _exs = comm.gather(_exs).mean(0)
-
- tensors = comm.broadcast_coalesced((_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
- for ts, queue in zip(tensors[1:], ctx.worker_queues):
- queue.put(ts)
- else:
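- # worker: push local statistics to the master, then wait for the averaged values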
- ctx.master_queue.put((_ex, _exs))
- _ex, _exs = ctx.worker_queue.get()
- ctx.worker_queue.task_done()
-
- # Update running stats with an exponential moving average:
- # stat <- (1 - momentum) * stat + momentum * batch_stat, with Var[x] = E[x^2] - E[x]^2
- _var = _exs - _ex ** 2
- running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * _ex)
- running_var.mul_(1 - ctx.momentum).add_(ctx.momentum * _var)
-
- # Mark in-place modified tensors
- ctx.mark_dirty(running_mean, running_var)
- else:
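- # inference: use the running statistics and recover E[x^2] = Var[x] + E[x]^2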
- _ex, _var = running_mean.contiguous(), running_var.contiguous()
- _exs = _var + _ex ** 2
-
- # BN forward: y = gamma * (x - E[x]) / sqrt(E[x^2] - E[x]^2 + eps) + beta
- y = _C.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
-
- # Output
- ctx.save_for_backward(x, _ex, _exs, gamma, beta)
- return y
-
- @staticmethod
- @once_differentiable
- def backward(ctx, dz):
- x, _ex, _exs, gamma, beta = ctx.saved_tensors
- dz = dz.contiguous()
-
- # BN backward
- dx, _dex, _dexs, dgamma, dbeta = _C.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)
-
- if ctx.training:
- if ctx.sync:
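- # mirror the forward synchronization for the gradients of E[x] and E[x^2]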
- if ctx.is_master:
- _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
- for _ in range(ctx.master_queue.maxsize):
- _dex_w, _dexs_w = ctx.master_queue.get()
- ctx.master_queue.task_done()
- _dex.append(_dex_w.unsqueeze(0))
- _dexs.append(_dexs_w.unsqueeze(0))
-
- _dex = comm.gather(_dex).mean(0)
- _dexs = comm.gather(_dexs).mean(0)
-
- tensors = comm.broadcast_coalesced((_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
- for ts, queue in zip(tensors[1:], ctx.worker_queues):
- queue.put(ts)
- else:
- ctx.master_queue.put((_dex, _dexs))
- _dex, _dexs = ctx.worker_queue.get()
- ctx.worker_queue.task_done()
-
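- # fold the gradient that flows through the shared batch statistics back into dx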
- dx_ = _C.expectation_backward(x, _dex, _dexs)
- dx = dx + dx_
-
- return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None
-
- @staticmethod
- def _parse_extra(ctx, extra):
- ctx.is_master = extra["is_master"]
- if ctx.is_master:
- ctx.master_queue = extra["master_queue"]
- ctx.worker_queues = extra["worker_queues"]
- ctx.worker_ids = extra["worker_ids"]
- else:
- ctx.master_queue = extra["master_queue"]
- ctx.worker_queue = extra["worker_queue"]
-
-
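- # functional interface used by SyncBatchNorm.forward below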
- syncbatchnorm = _SyncBatchNorm.apply
-
-
- class SyncBatchNorm(_BatchNorm):
- """Cross-GPU Synchronized Batch normalization (SyncBN)
-
- Parameters:
- num_features: num_features from an expected input of
- size batch_size x num_features x height x width
- eps: a value added to the denominator for numerical stability.
- Default: 1e-5
- momentum: the value used for the running_mean and running_var
- computation. Default: 0.1
- sync: a boolean value that, when set to ``True``, synchronizes the batch
- statistics across different GPUs. Default: ``True``
- activation : str
- Name of the activation function, one of: `leaky_relu` or `none`. Only `none` is currently supported. Default: `none`
- slope : float
- Negative slope for the `leaky_relu` activation.
-
- Shape:
- - Input: :math:`(N, C, H, W)`
- - Output: :math:`(N, C, H, W)` (same shape as input)
- Reference:
- .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015*
- .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
- Examples:
- >>> m = SyncBatchNorm(100).cuda()
- >>> net = torch.nn.DataParallel(m)
- >>> input = torch.rand(8, 100, 32, 32).cuda()
- >>> output = net(input)
- """
-
- def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation='none', slope=0.01):
- super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True)
- self.activation = activation
- self.slope = slope
- self.devices = list(range(torch.cuda.device_count()))
- self.sync = sync if len(self.devices) > 1 else False
- # Initialize queues
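- # master_queue receives per-GPU statistics from the workers; each worker_queue returns the reduced statistics to one worker GPU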
- self.worker_ids = self.devices[1:]
- self.master_queue = Queue(len(self.worker_ids))
- self.worker_queues = [Queue(1) for _ in self.worker_ids]
-
- def forward(self, x):
- # resize the input to (B, C, -1)
- input_shape = x.size()
- x = x.view(input_shape[0], self.num_features, -1)
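- # the first visible GPU acts as the master; every other GPU is a worker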
- if x.get_device() == self.devices[0]:
- # Master mode
- extra = {
- "is_master": True,
- "master_queue": self.master_queue,
- "worker_queues": self.worker_queues,
- "worker_ids": self.worker_ids
- }
- else:
- # Worker mode
- extra = {
- "is_master": False,
- "master_queue": self.master_queue,
- "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
- }
-
- return syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
- extra, self.sync, self.training, self.momentum, self.eps,
- self.activation, self.slope).view(input_shape)
-
- def extra_repr(self):
- if self.activation == 'none':
- return 'sync={}'.format(self.sync)
- else:
- return 'sync={}, act={}, slope={}'.format(
- self.sync, self.activation, self.slope)
-
-
- class BatchNorm1d(SyncBatchNorm):
- """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""
-
- def __init__(self, *args, **kwargs):
- warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
- .format('BatchNorm1d', SyncBatchNorm.__name__), DeprecationWarning)
- super(BatchNorm1d, self).__init__(*args, **kwargs)
-
-
- class BatchNorm2d(SyncBatchNorm):
- """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""
-
- def __init__(self, *args, **kwargs):
- warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
- .format('BatchNorm2d', SyncBatchNorm.__name__), DeprecationWarning)
- super(BatchNorm2d, self).__init__(*args, **kwargs)
-
-
- class BatchNorm3d(SyncBatchNorm):
- """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""
-
- def __init__(self, *args, **kwargs):
- warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
- .format('BatchNorm3d', SyncBatchNorm.__name__), DeprecationWarning)
- super(BatchNorm3d, self).__init__(*args, **kwargs)