# AIlib2/segutils/core/nn/syncbn.py

# Adapted from https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/syncbn.py
"""Synchronized Cross-GPU Batch Normalization Module"""
import warnings
import torch
import torch.cuda.comm as comm
from queue import Queue
from torch.autograd import Function
from torch.nn.modules.batchnorm import _BatchNorm
from torch.autograd.function import once_differentiable
from core.nn import _C
__all__ = ['SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']


class _SyncBatchNorm(Function):

    @classmethod
    def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
                extra, sync=True, training=True, momentum=0.1, eps=1e-05,
                activation="none", slope=0.01):
        # save context
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        assert activation == 'none'
        # contiguous inputs
        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()
        if ctx.training:
            # per-GPU statistics: per-channel E[x] and E[x^2]
            _ex, _exs = _C.expectation_forward(x)
            if ctx.sync:
                if ctx.is_master:
                    # master collects the statistics of every worker, averages them,
                    # and broadcasts the global statistics back to the workers
                    _ex, _exs = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _ex_w, _exs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _ex.append(_ex_w.unsqueeze(0))
                        _exs.append(_exs_w.unsqueeze(0))
                    _ex = comm.gather(_ex).mean(0)
                    _exs = comm.gather(_exs).mean(0)
                    tensors = comm.broadcast_coalesced((_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    # worker sends its local statistics and waits for the global ones
                    ctx.master_queue.put((_ex, _exs))
                    _ex, _exs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()
            # Update running stats
            _var = _exs - _ex ** 2
            running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * _ex)
            running_var.mul_(1 - ctx.momentum).add_(ctx.momentum * _var)
            # Mark in-place modified tensors
            ctx.mark_dirty(running_mean, running_var)
        else:
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex ** 2
        # BN forward
        y = _C.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
        # Output
        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return y

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        x, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()
        # BN backward
        dx, _dex, _dexs, dgamma, dbeta = _C.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)
        if ctx.training:
            if ctx.sync:
                if ctx.is_master:
                    _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _dex_w, _dexs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _dex.append(_dex_w.unsqueeze(0))
                        _dexs.append(_dexs_w.unsqueeze(0))
                    _dex = comm.gather(_dex).mean(0)
                    _dexs = comm.gather(_dexs).mean(0)
                    tensors = comm.broadcast_coalesced((_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    ctx.master_queue.put((_dex, _dexs))
                    _dex, _dexs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()
            dx_ = _C.expectation_backward(x, _dex, _dexs)
            dx = dx + dx_
        return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None

    @staticmethod
    def _parse_extra(ctx, extra):
        ctx.is_master = extra["is_master"]
        if ctx.is_master:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queue = extra["worker_queue"]


syncbatchnorm = _SyncBatchNorm.apply
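

# A pure-PyTorch reference sketch of the statistics the fused `_C` kernels are
# assumed to compute. It is kept for readability only and is not used by the module;
# the shapes and formulas are inferred from how `_ex`, `_exs` and
# `_var = _exs - _ex ** 2` are used in `_SyncBatchNorm` above. The function names
# are illustrative, not part of the extension's API.
def _expectation_forward_reference(x):
    # x has shape (B, C, -1); per-channel E[x] and E[x^2] over batch and spatial dims
    ex = x.mean(dim=(0, 2))
    exs = (x * x).mean(dim=(0, 2))
    return ex, exs


def _batchnorm_forward_reference(x, ex, exs, gamma, beta, eps):
    # standard affine batch normalization given E[x] and E[x^2]
    var = exs - ex ** 2
    x_hat = (x - ex.view(1, -1, 1)) / torch.sqrt(var.view(1, -1, 1) + eps)
    return gamma.view(1, -1, 1) * x_hat + beta.view(1, -1, 1)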


class SyncBatchNorm(_BatchNorm):
    """Cross-GPU Synchronized Batch normalization (SyncBN)

    Parameters:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        sync: a boolean value that, when set to ``True``, synchronizes the batch
            statistics across different GPUs. Default: ``True``
        activation : str
            Name of the activation function, one of: `leaky_relu` or `none`
            (only `none` is currently supported by this implementation).
        slope : float
            Negative slope for the `leaky_relu` activation.

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*

    Examples:
        >>> m = SyncBatchNorm(100)
        >>> net = torch.nn.DataParallel(m)
        >>> output = net(input)
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation='none', slope=0.01):
        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True)
        self.activation = activation
        self.slope = slope
        self.devices = list(range(torch.cuda.device_count()))
        # only synchronize when more than one GPU is visible
        self.sync = sync if len(self.devices) > 1 else False
        # Initialize queues: the master queue collects one statistics pair per worker,
        # and each worker queue receives the averaged statistics broadcast back by the master
        self.worker_ids = self.devices[1:]
        self.master_queue = Queue(len(self.worker_ids))
        self.worker_queues = [Queue(1) for _ in self.worker_ids]

    def forward(self, x):
        # resize the input to (B, C, -1)
        input_shape = x.size()
        x = x.view(input_shape[0], self.num_features, -1)
        if x.get_device() == self.devices[0]:
            # Master mode
            extra = {
                "is_master": True,
                "master_queue": self.master_queue,
                "worker_queues": self.worker_queues,
                "worker_ids": self.worker_ids
            }
        else:
            # Worker mode
            extra = {
                "is_master": False,
                "master_queue": self.master_queue,
                "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
            }
        return syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                             extra, self.sync, self.training, self.momentum, self.eps,
                             self.activation, self.slope).view(input_shape)

    def extra_repr(self):
        if self.activation == 'none':
            return 'sync={}'.format(self.sync)
        else:
            return 'sync={}, act={}, slope={}'.format(
                self.sync, self.activation, self.slope)


class BatchNorm1d(SyncBatchNorm):
    """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""

    def __init__(self, *args, **kwargs):
        warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
                      .format('BatchNorm1d', SyncBatchNorm.__name__), DeprecationWarning)
        super(BatchNorm1d, self).__init__(*args, **kwargs)


class BatchNorm2d(SyncBatchNorm):
    """BatchNorm2d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""

    def __init__(self, *args, **kwargs):
        warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
                      .format('BatchNorm2d', SyncBatchNorm.__name__), DeprecationWarning)
        super(BatchNorm2d, self).__init__(*args, **kwargs)


class BatchNorm3d(SyncBatchNorm):
    """BatchNorm3d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""

    def __init__(self, *args, **kwargs):
        warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
                      .format('BatchNorm3d', SyncBatchNorm.__name__), DeprecationWarning)
        super(BatchNorm3d, self).__init__(*args, **kwargs)
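

# A minimal smoke-test sketch, assuming the compiled `_C` extension has been built and
# at least one CUDA device is present (with a single visible GPU the layer falls back
# to unsynchronized statistics because `sync` is disabled in `__init__`). The tensor
# sizes below are arbitrary and chosen only for illustration.
if __name__ == '__main__':
    if torch.cuda.is_available():
        layer = SyncBatchNorm(16).cuda()
        net = torch.nn.DataParallel(layer)               # replicate over all visible GPUs
        data = torch.randn(8, 16, 32, 32, device='cuda')  # (N, C, H, W) input
        out = net(data)
        print(out.shape)                                  # torch.Size([8, 16, 32, 32])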