AIlib2/segutils/core/nn/psa_block.py

72 lines
1.4 KiB
Python

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable
#from core.nn import _C
__all__ = ['CollectAttention', 'DistributeAttention', 'psa_collect', 'psa_distribute']
class _PSACollect(torch.autograd.Function):
    """Autograd wrapper around the ``_C`` PSA kernels, called with flag ``1``.

    The forward pass delegates to ``_C.psa_forward`` and the backward pass to
    ``_C.psa_backward``; both receive the literal ``1`` as their last argument
    (presumably selecting the "collect" variant -- confirm against the
    extension source).

    NOTE(review): ``_C`` is not bound in the visible part of this file (the
    ``from core.nn import _C`` line is commented out), so it has to be made
    available some other way before this function is used.
    """

    @staticmethod
    def forward(ctx, hc):
        """Return ``_C.psa_forward(hc, 1)`` and stash ``hc`` for backward."""
        result = _C.psa_forward(hc, 1)
        # The input itself (not the output) is what backward needs.
        ctx.save_for_backward(hc)
        return result

    @staticmethod
    @once_differentiable
    def backward(ctx, dout):
        """Return the gradient w.r.t. ``hc`` computed by ``_C.psa_backward``."""
        # Exactly one tensor was saved in forward().
        (saved_hc,) = ctx.saved_tensors
        return _C.psa_backward(dout, saved_hc, 1)
class _PSADistribute(torch.autograd.Function):
    """Autograd wrapper around the ``_C`` PSA kernels, called with flag ``2``.

    The forward pass delegates to ``_C.psa_forward`` and the backward pass to
    ``_C.psa_backward``; both receive the literal ``2`` as their last argument
    (presumably selecting the "distribute" variant -- confirm against the
    extension source).

    NOTE(review): ``_C`` is not bound in the visible part of this file (the
    ``from core.nn import _C`` line is commented out), so it has to be made
    available some other way before this function is used.
    """

    @staticmethod
    def forward(ctx, hc):
        """Return ``_C.psa_forward(hc, 2)`` and stash ``hc`` for backward."""
        result = _C.psa_forward(hc, 2)
        # The input itself (not the output) is what backward needs.
        ctx.save_for_backward(hc)
        return result

    @staticmethod
    @once_differentiable
    def backward(ctx, dout):
        """Return the gradient w.r.t. ``hc`` computed by ``_C.psa_backward``."""
        # Exactly one tensor was saved in forward().
        (saved_hc,) = ctx.saved_tensors
        return _C.psa_backward(dout, saved_hc, 2)
# Functional entry points: each name is bound to the ``apply`` attribute of
# the corresponding autograd Function, so ``psa_collect(x)`` /
# ``psa_distribute(x)`` run that Function's forward pass on ``x``.
psa_collect = _PSACollect.apply
psa_distribute = _PSADistribute.apply
class CollectAttention(nn.Module):
    """Collect Attention Generation Module.

    A parameter-free ``nn.Module`` whose forward pass simply applies
    ``psa_collect`` to its input.
    """

    def __init__(self):
        super(CollectAttention, self).__init__()

    def forward(self, x):
        """Return ``psa_collect(x)``."""
        return psa_collect(x)
class DistributeAttention(nn.Module):
    """Distribute Attention Generation Module.

    A parameter-free ``nn.Module`` whose forward pass simply applies
    ``psa_distribute`` to its input.
    """

    def __init__(self):
        super(DistributeAttention, self).__init__()

    def forward(self, x):
        """Return ``psa_distribute(x)``."""
        return psa_distribute(x)