"""Point-wise Spatial Attention (PSA) collect/distribute operators backed by a compiled extension."""
import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

from core.nn import _C  # compiled C++/CUDA extension providing the PSA kernels

__all__ = ['CollectAttention', 'DistributeAttention', 'psa_collect', 'psa_distribute']


class _PSACollect(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hc):
        # Mode 1 selects the "collect" variant of the PSA kernel.
        out = _C.psa_forward(hc, 1)
        ctx.save_for_backward(hc)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, dout):
        hc, = ctx.saved_tensors
        dhc = _C.psa_backward(dout, hc, 1)
        return dhc


class _PSADistribute(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hc):
        # Mode 2 selects the "distribute" variant of the PSA kernel.
        out = _C.psa_forward(hc, 2)
        ctx.save_for_backward(hc)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, dout):
        hc, = ctx.saved_tensors
        dhc = _C.psa_backward(dout, hc, 2)
        return dhc


psa_collect = _PSACollect.apply
psa_distribute = _PSADistribute.apply


class CollectAttention(nn.Module):
    """Collect Attention Generation Module"""

    def __init__(self):
        super(CollectAttention, self).__init__()

    def forward(self, x):
        return psa_collect(x)


class DistributeAttention(nn.Module):
    """Distribute Attention Generation Module"""

    def __init__(self):
        super(DistributeAttention, self).__init__()

    def forward(self, x):
        return psa_distribute(x)
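
# ---------------------------------------------------------------------------
# Reference sketch (not part of the original file): a slow, pure-PyTorch
# version of the "collect" rearrangement, written under the ASSUMPTION that
# _C.psa_forward(hc, 1) implements the psamask-collect operator from PSANet,
# i.e. it maps an over-complete attention-logit map of shape
# (N, (2H-1)*(2W-1), H, W) to a per-position map of shape (N, H*W, H, W).
# If the compiled kernel differs, this sketch does not describe it.
def _psa_collect_reference(x):
    """Pure-PyTorch reference for the assumed collect semantics.

    Useful only for checking the extension on small inputs; the explicit
    Python loops are far too slow for training.
    """
    n, c, h, w = x.shape
    assert c == (2 * h - 1) * (2 * w - 1), 'expected (2H-1)*(2W-1) channels'
    half_h, half_w = h - 1, w - 1
    out = x.new_zeros(n, h * w, h, w)
    for i in range(h):
        for j in range(w):
            # Copy the logits whose relative offsets, centered at (i, j),
            # land inside the image; everything else stays zero.
            for hidx in range(max(0, half_h - i), min(2 * h - 1, 2 * h - 1 - i)):
                for widx in range(max(0, half_w - j), min(2 * w - 1, 2 * w - 1 - j)):
                    tgt = (hidx + i - half_h) * w + (widx + j - half_w)
                    out[:, tgt, i, j] = x[:, hidx * (2 * w - 1) + widx, i, j]
    return out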
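
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, with assumptions): the PSA kernels are
# CUDA-only in many builds, and the input shape below follows the PSANet
# convention of (2H-1)*(2W-1) attention-logit channels; both points are
# assumptions about this particular build of core.nn._C.
if __name__ == '__main__':
    if torch.cuda.is_available():
        h, w = 8, 8
        x = torch.randn(1, (2 * h - 1) * (2 * w - 1), h, w,
                        device='cuda', requires_grad=True)
        attn = CollectAttention()(x)  # (1, h*w, h, w) under the assumed semantics
        attn.sum().backward()         # exercises the custom backward pass
        print(attn.shape, x.grad.shape)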