AIlib2/segutils/core/nn/ca_block.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import once_differentiable

from core.nn import _C  # compiled CUDA extension providing the criss-cross kernels

__all__ = ['CrissCrossAttention', 'ca_weight', 'ca_map']


class _CAWeight(torch.autograd.Function):
    """Criss-cross affinity: correlates each pixel with its row and column."""

    @staticmethod
    def forward(ctx, t, f):
        # t: query features, f: key features; the CUDA kernel computes the
        # affinity of every pixel with the pixels on its criss-cross path.
        weight = _C.ca_forward(t, f)
        ctx.save_for_backward(t, f)
        return weight

    @staticmethod
    @once_differentiable
    def backward(ctx, dw):
        t, f = ctx.saved_tensors
        dt, df = _C.ca_backward(dw, t, f)
        return dt, df


class _CAMap(torch.autograd.Function):
    """Aggregation: applies the attention weights to the value features."""

    @staticmethod
    def forward(ctx, weight, g):
        # weight: softmax-normalized affinities, g: value features; the kernel
        # sums values along each pixel's row and column, weighted by attention.
        out = _C.ca_map_forward(weight, g)
        ctx.save_for_backward(weight, g)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, dout):
        weight, g = ctx.saved_tensors
        dw, dg = _C.ca_map_backward(dout, weight, g)
        return dw, dg


ca_weight = _CAWeight.apply
ca_map = _CAMap.apply


class CrissCrossAttention(nn.Module):
    """Criss-Cross Attention module (CCNet, Huang et al.)."""

    def __init__(self, in_channels):
        super(CrissCrossAttention, self).__init__()
        # Query and key are projected to a reduced width (C // 8); value keeps C.
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
        # Learnable residual scale, initialized to zero so the block starts as identity.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        proj_query = self.query_conv(x)
        proj_key = self.key_conv(x)
        proj_value = self.value_conv(x)
        energy = ca_weight(proj_query, proj_key)
        # Normalize each pixel's affinities over its criss-cross path.
        attention = F.softmax(energy, dim=1)
        out = ca_map(attention, proj_value)
        out = self.gamma * out + x
        return out
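

if __name__ == '__main__':
    # Minimal smoke-test sketch. Assumes the _C CUDA extension is built and a
    # GPU is available (the criss-cross kernels are CUDA-only); the input shape
    # below is illustrative, not taken from the original file.
    x = torch.randn(2, 64, 32, 32, device='cuda')
    cca = CrissCrossAttention(64).cuda()
    y = cca(x)
    assert y.shape == x.shape  # gamma * out + x preserves the input shape
    print(y.shape)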