You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

37 lines
2.2KB

  1. import torch.nn.functional as F
  2. import torch.nn as nn
  3. import torch
  4. class CombinationModule(nn.Module):
  5. def __init__(self, c_low, c_up, batch_norm=False, group_norm=False, instance_norm=False):
  6. super(CombinationModule, self).__init__()
  7. if batch_norm:
  8. self.up = nn.Sequential(nn.Conv2d(c_low, c_up, kernel_size=3, padding=1, stride=1),
  9. nn.BatchNorm2d(c_up),
  10. nn.ReLU(inplace=True))
  11. self.cat_conv = nn.Sequential(nn.Conv2d(c_up*2, c_up, kernel_size=1, stride=1),
  12. nn.BatchNorm2d(c_up),
  13. nn.ReLU(inplace=True))
  14. elif group_norm:
  15. self.up = nn.Sequential(nn.Conv2d(c_low, c_up, kernel_size=3, padding=1, stride=1),
  16. nn.GroupNorm(num_groups=32, num_channels=c_up),
  17. nn.ReLU(inplace=True))
  18. self.cat_conv = nn.Sequential(nn.Conv2d(c_up * 2, c_up, kernel_size=1, stride=1),
  19. nn.GroupNorm(num_groups=32, num_channels=c_up),
  20. nn.ReLU(inplace=True))
  21. elif instance_norm:
  22. self.up = nn.Sequential(nn.Conv2d(c_low, c_up, kernel_size=3, padding=1, stride=1),
  23. nn.InstanceNorm2d(num_features=c_up),
  24. nn.ReLU(inplace=True))
  25. self.cat_conv = nn.Sequential(nn.Conv2d(c_up * 2, c_up, kernel_size=1, stride=1),
  26. nn.InstanceNorm2d(num_features=c_up),
  27. nn.ReLU(inplace=True))
  28. else:
  29. self.up = nn.Sequential(nn.Conv2d(c_low, c_up, kernel_size=3, padding=1, stride=1),
  30. nn.ReLU(inplace=True))
  31. self.cat_conv = nn.Sequential(nn.Conv2d(c_up*2, c_up, kernel_size=1, stride=1),
  32. nn.ReLU(inplace=True))
  33. def forward(self, x_low, x_up):
  34. x_low = self.up(F.interpolate(x_low, x_up.shape[2:], mode='bilinear', align_corners=False))
  35. return self.cat_conv(torch.cat((x_up, x_low), 1))