# Adapted from https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/syncbn.py
"""Synchronized Cross-GPU Batch Normalization Module"""
|
|
import warnings
|
|
import torch
|
|
import torch.cuda.comm as comm
|
|
|
|
from queue import Queue
|
|
from torch.autograd import Function
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|
from torch.autograd.function import once_differentiable
|
|
from core.nn import _C
|
|
|
|
__all__ = ['SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']
|
|
|
|
|
|
class _SyncBatchNorm(Function):
    @classmethod
    def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
                extra, sync=True, training=True, momentum=0.1, eps=1e-05,
                activation="none", slope=0.01):
        # save context
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        # only the plain path (no fused activation) is implemented here
        assert activation == 'none'

        # make inputs contiguous
        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()

        if ctx.training:
            # per-GPU first and second moments, E[x] and E[x^2] per channel
            _ex, _exs = _C.expectation_forward(x)
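            # For reference, an equivalent pure-PyTorch sketch of these moments
            # (an assumption about the CUDA kernel, not its actual code; x has
            # already been reshaped to (B, C, -1) by SyncBatchNorm.forward):
            #   _ex = x.mean(dim=(0, 2))          # E[x] per channel
            #   _exs = (x * x).mean(dim=(0, 2))   # E[x^2] per channel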

            if ctx.sync:
                if ctx.is_master:
                    # collect the per-GPU moments pushed by every worker
                    _ex, _exs = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _ex_w, _exs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _ex.append(_ex_w.unsqueeze(0))
                        _exs.append(_exs_w.unsqueeze(0))

                    # average the moments across GPUs
                    _ex = comm.gather(_ex).mean(0)
                    _exs = comm.gather(_exs).mean(0)

                    # broadcast the global moments back to every worker
                    tensors = comm.broadcast_coalesced((_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    # hand the local moments to the master, then block until
                    # the globally averaged moments come back
                    ctx.master_queue.put((_ex, _exs))
                    _ex, _exs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

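            # Note: the unweighted mean(0) above yields the true global moments
            # only when every GPU processes the same number of elements; uneven
            # DataParallel splits would need a weighted average.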
            # update running stats, using Var[x] = E[x^2] - E[x]^2
            _var = _exs - _ex ** 2
            running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * _ex)
            running_var.mul_(1 - ctx.momentum).add_(ctx.momentum * _var)

            # mark the in-place modified tensors
            ctx.mark_dirty(running_mean, running_var)
        else:
            # in eval mode, recover E[x] and E[x^2] from the running stats
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex ** 2

        # BN forward
        y = _C.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
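        # For reference, the normalization this kernel presumably applies
        # (an assumption based on the argument list, not the actual _C code):
        #   _std = (_exs - _ex ** 2 + ctx.eps).sqrt()
        #   y = gamma.view(1, -1, 1) * (x - _ex.view(1, -1, 1)) / _std.view(1, -1, 1) + beta.view(1, -1, 1)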

        # Output
        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return y

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        x, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # BN backward: input and affine gradients, plus the gradients w.r.t.
        # the per-GPU moments E[x] and E[x^2]
        dx, _dex, _dexs, dgamma, dbeta = _C.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)

        if ctx.training:
            if ctx.sync:
                # mirror the forward pass: average the moment gradients across
                # GPUs so every replica applies the same correction
                if ctx.is_master:
                    _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _dex_w, _dexs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _dex.append(_dex_w.unsqueeze(0))
                        _dexs.append(_dexs_w.unsqueeze(0))

                    _dex = comm.gather(_dex).mean(0)
                    _dexs = comm.gather(_dexs).mean(0)

                    tensors = comm.broadcast_coalesced((_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    ctx.master_queue.put((_dex, _dexs))
                    _dex, _dexs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            # propagate the moment gradients back into the input gradient
            dx_ = _C.expectation_backward(x, _dex, _dexs)
            dx = dx + dx_
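            # A sketch of the chain rule this presumably implements (an
            # assumption, not the actual _C code): with N elements reduced per
            # channel, d(E[x])/dx = 1/N and d(E[x^2])/dx = 2x/N, so
            #   dx_ = (_dex.view(1, -1, 1) + 2 * x * _dexs.view(1, -1, 1)) / N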

        # one gradient per argument of forward() after ctx; the nine
        # non-differentiable arguments receive None
        return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None

    @staticmethod
    def _parse_extra(ctx, extra):
        ctx.is_master = extra["is_master"]
        ctx.master_queue = extra["master_queue"]
        if ctx.is_master:
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.worker_queue = extra["worker_queue"]


syncbatchnorm = _SyncBatchNorm.apply
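# Note: like other autograd Functions, this alias is called with positional
# arguments matching _SyncBatchNorm.forward's signature, as done in
# SyncBatchNorm.forward below.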


class SyncBatchNorm(_BatchNorm):
    """Cross-GPU Synchronized Batch Normalization (SyncBN)

    Parameters:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        sync: a boolean value that, when set to ``True``, synchronizes the
            statistics across different GPUs. Default: ``True``
        activation : str
            Name of the activation function, one of: `leaky_relu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*

    Examples:
        >>> m = SyncBatchNorm(100)
        >>> net = torch.nn.DataParallel(m)
        >>> output = net(input)
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation='none', slope=0.01):
        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True)
        self.activation = activation
        self.slope = slope
        self.devices = list(range(torch.cuda.device_count()))
        # synchronization is only meaningful with more than one device
        self.sync = sync if len(self.devices) > 1 else False
        # Initialize queues
        self.worker_ids = self.devices[1:]
        self.master_queue = Queue(len(self.worker_ids))
        self.worker_queues = [Queue(1) for _ in self.worker_ids]
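        # Design note: the master queue must hold one (E[x], E[x^2]) pair from
        # each worker per step, while each worker only ever waits on a single
        # broadcast result, hence its capacity of 1.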

    def forward(self, x):
        # flatten the spatial dimensions: reshape the input to (B, C, -1)
        input_shape = x.size()
        x = x.view(input_shape[0], self.num_features, -1)
        if x.get_device() == self.devices[0]:
            # Master mode
            extra = {
                "is_master": True,
                "master_queue": self.master_queue,
                "worker_queues": self.worker_queues,
                "worker_ids": self.worker_ids
            }
        else:
            # Worker mode
            extra = {
                "is_master": False,
                "master_queue": self.master_queue,
                "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
            }

        return syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                             extra, self.sync, self.training, self.momentum, self.eps,
                             self.activation, self.slope).view(input_shape)

    def extra_repr(self):
        if self.activation == 'none':
            return 'sync={}'.format(self.sync)
        else:
            return 'sync={}, act={}, slope={}'.format(
                self.sync, self.activation, self.slope)


class BatchNorm1d(SyncBatchNorm):
    """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""

    def __init__(self, *args, **kwargs):
        warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
                      .format('BatchNorm1d', SyncBatchNorm.__name__), DeprecationWarning)
        super(BatchNorm1d, self).__init__(*args, **kwargs)


class BatchNorm2d(SyncBatchNorm):
    """BatchNorm2d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""

    def __init__(self, *args, **kwargs):
        warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
                      .format('BatchNorm2d', SyncBatchNorm.__name__), DeprecationWarning)
        super(BatchNorm2d, self).__init__(*args, **kwargs)


class BatchNorm3d(SyncBatchNorm):
    """BatchNorm3d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""

    def __init__(self, *args, **kwargs):
        warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
                      .format('BatchNorm3d', SyncBatchNorm.__name__), DeprecationWarning)
        super(BatchNorm3d, self).__init__(*args, **kwargs)