用kafka接收消息
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ca_block.py 1.7KB

2 years ago
Lines 1–72
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.autograd.function import once_differentiable
  5. #from core.nn import _C
  6. __all__ = ['CrissCrossAttention', 'ca_weight', 'ca_map']
  7. class _CAWeight(torch.autograd.Function):
  8. @staticmethod
  9. def forward(ctx, t, f):
  10. weight = _C.ca_forward(t, f)
  11. ctx.save_for_backward(t, f)
  12. return weight
  13. @staticmethod
  14. @once_differentiable
  15. def backward(ctx, dw):
  16. t, f = ctx.saved_tensors
  17. dt, df = _C.ca_backward(dw, t, f)
  18. return dt, df
  19. class _CAMap(torch.autograd.Function):
  20. @staticmethod
  21. def forward(ctx, weight, g):
  22. out = _C.ca_map_forward(weight, g)
  23. ctx.save_for_backward(weight, g)
  24. return out
  25. @staticmethod
  26. @once_differentiable
  27. def backward(ctx, dout):
  28. weight, g = ctx.saved_tensors
  29. dw, dg = _C.ca_map_backward(dout, weight, g)
  30. return dw, dg
  31. ca_weight = _CAWeight.apply
  32. ca_map = _CAMap.apply
  33. class CrissCrossAttention(nn.Module):
  34. """Criss-Cross Attention Module"""
  35. def __init__(self, in_channels):
  36. super(CrissCrossAttention, self).__init__()
  37. self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
  38. self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
  39. self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
  40. self.gamma = nn.Parameter(torch.zeros(1))
  41. def forward(self, x):
  42. proj_query = self.query_conv(x)
  43. proj_key = self.key_conv(x)
  44. proj_value = self.value_conv(x)
  45. energy = ca_weight(proj_query, proj_key)
  46. attention = F.softmax(energy, 1)
  47. out = ca_map(attention, proj_value)
  48. out = self.gamma * out + x
  49. return out