Receiving messages with Kafka

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

from core.nn import _C  # compiled extension providing psa_forward / psa_backward

__all__ = ['CollectAttention', 'DistributeAttention', 'psa_collect', 'psa_distribute']


class _PSACollect(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hc):
        # Mode 1 selects the "collect" variant of the PSA kernel.
        out = _C.psa_forward(hc, 1)
        ctx.save_for_backward(hc)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, dout):
        hc, = ctx.saved_tensors
        dhc = _C.psa_backward(dout, hc, 1)
        return dhc


class _PSADistribute(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hc):
        # Mode 2 selects the "distribute" variant of the PSA kernel.
        out = _C.psa_forward(hc, 2)
        ctx.save_for_backward(hc)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, dout):
        hc, = ctx.saved_tensors
        dhc = _C.psa_backward(dout, hc, 2)
        return dhc


psa_collect = _PSACollect.apply
psa_distribute = _PSADistribute.apply


class CollectAttention(nn.Module):
    """Collect Attention Generation Module"""

    def __init__(self):
        super(CollectAttention, self).__init__()

    def forward(self, x):
        return psa_collect(x)


class DistributeAttention(nn.Module):
    """Distribute Attention Generation Module"""

    def __init__(self):
        super(DistributeAttention, self).__init__()

    def forward(self, x):
        return psa_distribute(x)
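
The modules above appear to wrap the point-wise spatial attention (PSA) collect/distribute operators, in the style of PSANet, around a compiled extension `_C`. Below is a minimal usage sketch under stated assumptions: the `core.nn._C` extension is built and importable, the file above is saved as a module named `psa` (a hypothetical path), and the kernels accept a 4-D attention tensor. The exact shape contract (e.g. whether the channel dimension must equal H*W, as in PSANet) is fixed by the C++/CUDA side and is an assumption here.

    # Hypothetical usage; requires the compiled core.nn._C extension.
    import torch
    from psa import CollectAttention, DistributeAttention  # hypothetical module path

    collect = CollectAttention()
    distribute = DistributeAttention()

    # Assumed shape: batch 2, one attention weight per spatial position (8*8), on an 8x8 grid.
    hc = torch.randn(2, 8 * 8, 8, 8, requires_grad=True)

    out = collect(hc)      # dispatches to _C.psa_forward(hc, 1)
    out.sum().backward()   # gradients flow through _C.psa_backward
    print(out.shape, hc.grad.shape)

Because the backward passes are decorated with `once_differentiable`, these operators support a single backward pass but not double backprop (gradients of gradients).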