286 lines
10 KiB
Python
286 lines
10 KiB
Python
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
## Created by: Hang Zhang
|
|
## Email: zhanghang0704@gmail.com
|
|
## Copyright (c) 2018
|
|
##
|
|
## This source code is licensed under the MIT-style license found in the
|
|
## LICENSE file in the root directory of this source tree
|
|
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
|
|
"""Synchronized Cross-GPU Batch Normalization functions"""
|
|
import torch.cuda.comm as comm
|
|
|
|
from torch.autograd import Function
|
|
from torch.autograd.function import once_differentiable
|
|
from core.nn.sync_bn import lib
|
|
|
|
__all__ = ['syncbatchnorm', 'inp_syncbatchnorm']
|
|
|
|
|
|
class syncbatchnorm_(Function):
    """Synchronized cross-GPU batch normalization (out-of-place variant).

    Per-device first/second moments (E[x], E[x^2]) are computed locally,
    then averaged across devices through a master/worker queue protocol so
    every GPU normalizes with the same global statistics. The ``extra``
    dict carries the queue handles (see ``_parse_extra``).
    """

    @classmethod
    def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
                extra, sync=True, training=True, momentum=0.1, eps=1e-05,
                activation="none", slope=0.01):
        # Save cross-device bookkeeping and hyper-parameters on the context.
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        # This non-inplace variant does not fuse an activation.
        assert activation == 'none'

        # The extension kernels require contiguous inputs.
        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()

        if ctx.training:
            # Local moments: _ex = E[x], _exs = E[x^2] per channel.
            if x.is_cuda:
                _ex, _exs = lib.gpu.expectation_forward(x)
            else:
                # BUG FIX: was `raise NotImplemented`, which itself raises
                # TypeError (NotImplemented is not an exception class).
                raise NotImplementedError

            if ctx.sync:
                if ctx.is_master:
                    # Collect moments from every worker, then average.
                    _ex, _exs = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _ex_w, _exs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _ex.append(_ex_w.unsqueeze(0))
                        _exs.append(_exs_w.unsqueeze(0))

                    _ex = comm.gather(_ex).mean(0)
                    _exs = comm.gather(_exs).mean(0)

                    # Broadcast global moments back to all workers.
                    tensors = comm.broadcast_coalesced((_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    # Worker: hand local moments to master, wait for globals.
                    ctx.master_queue.put((_ex, _exs))
                    _ex, _exs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            # Update running stats (EMA with the given momentum).
            _var = _exs - _ex ** 2
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * _ex)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * _var)

            # Mark in-place modified tensors for autograd.
            ctx.mark_dirty(running_mean, running_var)
        else:
            # Evaluation: derive moments from the stored running statistics.
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex ** 2

        # BN forward: y = gamma * (x - _ex) / sqrt(_exs - _ex^2 + eps) + beta
        # (exact formula lives in the extension kernel).
        if x.is_cuda:
            y = lib.gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
        else:
            y = lib.cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        # Save tensors needed by backward.
        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return y

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        x, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # BN backward: input gradient plus gradients w.r.t. the moments.
        if dz.is_cuda:
            dx, _dex, _dexs, dgamma, dbeta = lib.gpu.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)
        else:
            # BUG FIX: was `raise NotImplemented` (TypeError at runtime).
            raise NotImplementedError

        if ctx.training:
            if ctx.sync:
                # Mirror the forward reduction for the moment gradients.
                if ctx.is_master:
                    _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _dex_w, _dexs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _dex.append(_dex_w.unsqueeze(0))
                        _dexs.append(_dexs_w.unsqueeze(0))

                    _dex = comm.gather(_dex).mean(0)
                    _dexs = comm.gather(_dexs).mean(0)

                    tensors = comm.broadcast_coalesced((_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    ctx.master_queue.put((_dex, _dexs))
                    _dex, _dexs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            # Propagate the statistics gradient back into dx.
            if x.is_cuda:
                dx_ = lib.gpu.expectation_backward(x, _dex, _dexs)
            else:
                # BUG FIX: was `raise NotImplemented` (TypeError at runtime).
                raise NotImplementedError
            dx = dx + dx_

        # One gradient slot per forward argument; non-tensor args get None.
        return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None

    @staticmethod
    def _parse_extra(ctx, extra):
        # Stash the cross-device queue handles on the autograd context.
        ctx.is_master = extra["is_master"]
        if ctx.is_master:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queue = extra["worker_queue"]
|
|
|
|
|
|
def _act_forward(ctx, x):
|
|
if ctx.activation.lower() == "leaky_relu":
|
|
if x.is_cuda:
|
|
lib.gpu.leaky_relu_forward(x, ctx.slope)
|
|
else:
|
|
raise NotImplemented
|
|
else:
|
|
assert ctx.activation == 'none'
|
|
|
|
|
|
def _act_backward(ctx, x, dx):
|
|
if ctx.activation.lower() == "leaky_relu":
|
|
if x.is_cuda:
|
|
lib.gpu.leaky_relu_backward(x, dx, ctx.slope)
|
|
else:
|
|
raise NotImplemented
|
|
else:
|
|
assert ctx.activation == 'none'
|
|
|
|
|
|
class inp_syncbatchnorm_(Function):
    """Synchronized cross-GPU batch normalization, in-place variant with an
    optionally fused (leaky-)ReLU activation.

    Same master/worker reduction protocol as ``syncbatchnorm_``, but the
    normalization (and activation) overwrite ``x`` in place.
    """

    @classmethod
    def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
                extra, sync=True, training=True, momentum=0.1, eps=1e-5,
                activation='none', slope=0.01):
        # Save cross-device bookkeeping and hyper-parameters on the context.
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope

        # The extension kernels require contiguous inputs.
        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()

        if ctx.training:
            # Local moments: _ex = E[x], _exs = E[x^2] per channel.
            if x.is_cuda:
                _ex, _exs = lib.gpu.expectation_forward(x)
            else:
                # BUG FIX: was `raise NotImplemented`, which itself raises
                # TypeError (NotImplemented is not an exception class).
                raise NotImplementedError

            if ctx.sync:
                if ctx.is_master:
                    # Collect moments from every worker, then average.
                    _ex, _exs = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _ex_w, _exs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _ex.append(_ex_w.unsqueeze(0))
                        # BUG FIX: was `_exs_w.unsuqeeze(0)` — an
                        # AttributeError on the master sync path.
                        _exs.append(_exs_w.unsqueeze(0))

                    _ex = comm.gather(_ex).mean(0)
                    _exs = comm.gather(_exs).mean(0)

                    # Broadcast global moments back to all workers.
                    tensors = comm.broadcast_coalesced((_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    # Worker: hand local moments to master, wait for globals.
                    ctx.master_queue.put((_ex, _exs))
                    _ex, _exs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            # Update running stats (EMA with the given momentum).
            _var = _exs - _ex ** 2
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * _ex)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * _var)

            # Mark in-place modified tensors for autograd
            # (x itself is overwritten below).
            ctx.mark_dirty(x, running_mean, running_var)
        else:
            # Evaluation: derive moments from the stored running statistics.
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex ** 2
            ctx.mark_dirty(x)

        # BN forward (in place) + fused activation.
        if x.is_cuda:
            lib.gpu.batchnorm_inp_forward(x, _ex, _exs, gamma, beta, ctx.eps)
        else:
            # BUG FIX: was `raise NotImplemented` (TypeError at runtime).
            raise NotImplementedError

        _act_forward(ctx, x)

        # Save tensors needed by backward; x now holds the output.
        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        z, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo the fused activation in place before the BN backward.
        _act_backward(ctx, z, dz)

        # BN backward: input gradient plus gradients w.r.t. the moments.
        if dz.is_cuda:
            dx, _dex, _dexs, dgamma, dbeta = lib.gpu.batchnorm_inp_backward(dz, z, _ex, _exs, gamma, beta, ctx.eps)
        else:
            # BUG FIX: was `raise NotImplemented` (TypeError at runtime).
            raise NotImplementedError

        if ctx.training:
            if ctx.sync:
                # Mirror the forward reduction for the moment gradients.
                if ctx.is_master:
                    _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _dex_w, _dexs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _dex.append(_dex_w.unsqueeze(0))
                        _dexs.append(_dexs_w.unsqueeze(0))

                    _dex = comm.gather(_dex).mean(0)
                    _dexs = comm.gather(_dexs).mean(0)

                    tensors = comm.broadcast_coalesced((_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    ctx.master_queue.put((_dex, _dexs))
                    _dex, _dexs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            # Fold the statistics gradient into dx (in place, from output z).
            if z.is_cuda:
                lib.gpu.expectation_inp_backward(dx, z, _dex, _dexs, _ex, _exs, gamma, beta, ctx.eps)
            else:
                # BUG FIX: was `raise NotImplemented` (TypeError at runtime).
                raise NotImplementedError

        # One gradient slot per forward argument; non-tensor args get None.
        return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None

    @staticmethod
    def _parse_extra(ctx, extra):
        # Stash the cross-device queue handles on the autograd context.
        ctx.is_master = extra["is_master"]
        if ctx.is_master:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queue = extra["worker_queue"]
|
|
|
|
|
|
# Public functional entry points (listed in __all__). Autograd Function
# subclasses are invoked through `.apply`, so callers use these aliases
# rather than instantiating the classes directly.
syncbatchnorm = syncbatchnorm_.apply
inp_syncbatchnorm = inp_syncbatchnorm_.apply
|