# Adapted from https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/syncbn.py
"""Synchronized Cross-GPU Batch Normalization Module"""
import warnings

import torch
import torch.cuda.comm as comm
from queue import Queue
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.batchnorm import _BatchNorm

from core.nn import _C

__all__ = ['SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']


class _SyncBatchNorm(Function):
    @classmethod
    def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
                extra, sync=True, training=True, momentum=0.1, eps=1e-05,
                activation="none", slope=0.01):
        # save context
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        assert activation == 'none'

        # contiguous inputs
        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()

        if ctx.training:
            # per-device per-channel E[x] and E[x^2]
            _ex, _exs = _C.expectation_forward(x)

            if ctx.sync:
                if ctx.is_master:
                    # master collects the statistics from every worker,
                    # averages them, and broadcasts the result back
                    _ex, _exs = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _ex_w, _exs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _ex.append(_ex_w.unsqueeze(0))
                        _exs.append(_exs_w.unsqueeze(0))

                    _ex = comm.gather(_ex).mean(0)
                    _exs = comm.gather(_exs).mean(0)

                    tensors = comm.broadcast_coalesced((_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    # worker sends its statistics and waits for the global ones
                    ctx.master_queue.put((_ex, _exs))
                    _ex, _exs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            # Update running stats
            _var = _exs - _ex ** 2
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * _ex)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * _var)

            # Mark in-place modified tensors
            ctx.mark_dirty(running_mean, running_var)
        else:
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex ** 2

        # BN forward
        y = _C.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        # Output
        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return y

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        x, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # BN backward
        dx, _dex, _dexs, dgamma, dbeta = _C.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)

        if ctx.training:
            if ctx.sync:
                if ctx.is_master:
                    # master gathers and averages the statistic gradients,
                    # then broadcasts them back to the workers
                    _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _dex_w, _dexs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _dex.append(_dex_w.unsqueeze(0))
                        _dexs.append(_dexs_w.unsqueeze(0))

                    _dex = comm.gather(_dex).mean(0)
                    _dexs = comm.gather(_dexs).mean(0)

                    tensors = comm.broadcast_coalesced((_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    ctx.master_queue.put((_dex, _dexs))
                    _dex, _dexs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            dx_ = _C.expectation_backward(x, _dex, _dexs)
            dx = dx + dx_

        return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None

    @staticmethod
    def _parse_extra(ctx, extra):
        ctx.is_master = extra["is_master"]
        if ctx.is_master:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queue = extra["worker_queue"]


syncbatchnorm = _SyncBatchNorm.apply


class SyncBatchNorm(_BatchNorm):
    """Cross-GPU Synchronized Batch Normalization (SyncBN).

    Parameters:
        num_features: number of channels `C` from an expected input of size
            ``batch_size x num_features x height x width``
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        sync: a boolean value that, when set to ``True``, synchronizes the
            statistics across different GPUs. Default: ``True``
        activation : str
            Name of the activation function, one of: `leaky_relu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep
               network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
               Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation."
               *CVPR 2018*

    Examples:
        >>> m = SyncBatchNorm(100)
        >>> net = torch.nn.DataParallel(m)
        >>> output = net(input)
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation='none', slope=0.01):
        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True)
        self.activation = activation
        self.slope = slope
        self.devices = list(range(torch.cuda.device_count()))
        self.sync = sync if len(self.devices) > 1 else False
        # Initialize queues
        self.worker_ids = self.devices[1:]
        self.master_queue = Queue(len(self.worker_ids))
        self.worker_queues = [Queue(1) for _ in self.worker_ids]

    def forward(self, x):
        # resize the input to (B, C, -1)
        input_shape = x.size()
        x = x.view(input_shape[0], self.num_features, -1)
        if x.get_device() == self.devices[0]:
            # Master mode
            extra = {
                "is_master": True,
                "master_queue": self.master_queue,
                "worker_queues": self.worker_queues,
                "worker_ids": self.worker_ids
            }
        else:
            # Worker mode
            extra = {
                "is_master": False,
                "master_queue": self.master_queue,
                "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
            }
        return syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                             extra, self.sync, self.training, self.momentum, self.eps,
                             self.activation, self.slope).view(input_shape)

    def extra_repr(self):
        if self.activation == 'none':
            return 'sync={}'.format(self.sync)
        else:
            return 'sync={}, act={}, slope={}'.format(self.sync, self.activation, self.slope)


class BatchNorm1d(SyncBatchNorm):
    """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""

    def __init__(self, *args, **kwargs):
        warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
                      .format('BatchNorm1d', SyncBatchNorm.__name__), DeprecationWarning)
        super(BatchNorm1d, self).__init__(*args, **kwargs)


class BatchNorm2d(SyncBatchNorm):
    """BatchNorm2d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""

    def __init__(self, *args, **kwargs):
        warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
                      .format('BatchNorm2d', SyncBatchNorm.__name__), DeprecationWarning)
        super(BatchNorm2d, self).__init__(*args, **kwargs)


class BatchNorm3d(SyncBatchNorm):
    """BatchNorm3d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""

    def __init__(self, *args, **kwargs):
        warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
                      .format('BatchNorm3d', SyncBatchNorm.__name__), DeprecationWarning)
        super(BatchNorm3d, self).__init__(*args, **kwargs)
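

# ---------------------------------------------------------------------------
# Illustrative reference sketch (an assumption added for clarity, not part of
# the original API): plain-PyTorch version of the statistics the compiled
# `core.nn._C` kernels and the queue protocol above compute. Each "device"
# contributes per-channel E[x] and E[x^2]; averaging those across devices
# (exact when the per-device chunks are equally sized, as DataParallel
# produces) and using var = E[x^2] - E[x]^2 normalizes with whole-batch
# statistics. The name `_reference_sync_bn` is hypothetical.
def _reference_sync_bn(chunks, gamma, beta, eps=1e-5):
    # chunks: list of tensors shaped (B_i, C, L), one per device replica
    ex = torch.stack([c.mean(dim=(0, 2)) for c in chunks]).mean(0)          # global E[x]
    exs = torch.stack([(c ** 2).mean(dim=(0, 2)) for c in chunks]).mean(0)  # global E[x^2]
    inv_std = torch.rsqrt(exs - ex ** 2 + eps)
    # normalize every chunk with the shared statistics
    return [gamma.view(1, -1, 1) * (c - ex.view(1, -1, 1)) * inv_std.view(1, -1, 1)
            + beta.view(1, -1, 1) for c in chunks]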
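

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): wrap a model containing
# SyncBatchNorm in torch.nn.DataParallel so every replica shares the
# master/worker queues created in SyncBatchNorm.__init__. Assumes the compiled
# `core.nn._C` extension is importable and at least one CUDA device is
# available; the layer sizes below are arbitrary.
if __name__ == '__main__':
    if torch.cuda.is_available():
        model = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
            SyncBatchNorm(64),
            torch.nn.ReLU(inplace=True),
        ).cuda()
        # DataParallel scatters the batch across all visible GPUs; the BN
        # statistics are synchronized through the queues during forward.
        model = torch.nn.DataParallel(model)
        out = model(torch.randn(8, 3, 32, 32).cuda())
        print(out.shape)  # expected: torch.Size([8, 64, 32, 32])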