import torch
from torch import nn
import torch.nn.functional as F
import core.lib.psa.functional as PF
import modeling.backbone.resnet_real as models

# Fails to run: the compact mode works, but the non-compact (over-complete) mode does not. The
# failure comes from the psamask implementation, which relies on a custom torch.autograd.Function
# backed by C++/CUDA sources; importing the compiled _C module fails. A hedged pure-PyTorch
# fallback sketch is added after the commented-out reference implementation below.

# from . import functions
#
#
# def psa_mask(input, psa_type=0, mask_H_=None, mask_W_=None):
#     return functions.psa_mask(input, psa_type, mask_H_, mask_W_)
#
#
# import torch
# from torch.autograd import Function
# from .. import src


# class PSAMask(Function):
#     @staticmethod
#     def forward(ctx, input, psa_type=0, mask_H_=None, mask_W_=None):
#         assert psa_type in [0, 1]  # 0-col, 1-dis
#         assert (mask_H_ is None and mask_W_ is None) or (mask_H_ is not None and mask_W_ is not None)
#         num_, channels_, feature_H_, feature_W_ = input.size()
#         if mask_H_ is None and mask_W_ is None:
#             mask_H_, mask_W_ = 2 * feature_H_ - 1, 2 * feature_W_ - 1
#         assert (mask_H_ % 2 == 1) and (mask_W_ % 2 == 1)
#         assert channels_ == mask_H_ * mask_W_
#         half_mask_H_, half_mask_W_ = (mask_H_ - 1) // 2, (mask_W_ - 1) // 2
#         output = torch.zeros([num_, feature_H_ * feature_W_, feature_H_, feature_W_], dtype=input.dtype, device=input.device)
#         if not input.is_cuda:
#             src.cpu.psamask_forward(psa_type, input, output, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
#         else:
#             output = output.cuda()
#             src.gpu.psamask_forward(psa_type, input, output, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
#         ctx.psa_type, ctx.num_, ctx.channels_, ctx.feature_H_, ctx.feature_W_ = psa_type, num_, channels_, feature_H_, feature_W_
#         ctx.mask_H_, ctx.mask_W_, ctx.half_mask_H_, ctx.half_mask_W_ = mask_H_, mask_W_, half_mask_H_, half_mask_W_
#         return output
#
#     @staticmethod
#     def backward(ctx, grad_output):
#         psa_type, num_, channels_, feature_H_, feature_W_ = ctx.psa_type, ctx.num_, ctx.channels_, ctx.feature_H_, ctx.feature_W_
#         mask_H_, mask_W_, half_mask_H_, half_mask_W_ = ctx.mask_H_, ctx.mask_W_, ctx.half_mask_H_, ctx.half_mask_W_
#         grad_input = torch.zeros([num_, channels_, feature_H_, feature_W_], dtype=grad_output.dtype, device=grad_output.device)
#         if not grad_output.is_cuda:
#             src.cpu.psamask_backward(psa_type, grad_output, grad_input, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
#         else:
#             src.gpu.psamask_backward(psa_type, grad_output, grad_input, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
#         return grad_input, None, None, None


# psa_mask = PSAMask.apply

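# Hedged fallback (our addition, not from the original repo): a pure-PyTorch psa_mask written
# from the semantics of the commented-out C/CUDA kernels above, for the case where the compiled
# _C extension cannot be imported. It loops over the mask window, so it is only meant for
# debugging / correctness checks, not training speed; the name `psa_mask_py` is ours, and the
# PF.psa_mask calls below could be pointed at it manually if the extension is unavailable.
def psa_mask_py(input, psa_type=0, mask_H_=None, mask_W_=None):
    assert psa_type in [0, 1]  # 0: collect, 1: distribute
    n, c, h, w = input.size()
    if mask_H_ is None and mask_W_ is None:
        mask_H_, mask_W_ = 2 * h - 1, 2 * w - 1
    assert (mask_H_ % 2 == 1) and (mask_W_ % 2 == 1) and c == mask_H_ * mask_W_
    half_H, half_W = (mask_H_ - 1) // 2, (mask_W_ - 1) // 2
    # view the h * w output channels as a second spatial grid
    out = input.new_zeros(n, h, w, h, w)
    ii = torch.arange(h, device=input.device)
    jj = torch.arange(w, device=input.device)
    for hidx in range(mask_H_):
        for widx in range(mask_W_):
            dh, dw = hidx - half_H, widx - half_W
            # "current" positions (i, j) whose offset position (i + dh, j + dw) stays on the map
            i0, i1 = max(0, -dh), min(h, h - dh)
            j0, j1 = max(0, -dw), min(w, w - dw)
            if i0 >= i1 or j0 >= j1:
                continue
            src = input[:, hidx * mask_W_ + widx, i0:i1, j0:j1]
            ci = ii[i0:i1, None]  # shape (i1 - i0, 1)
            cj = jj[None, j0:j1]  # shape (1, j1 - j0)
            if psa_type == 0:  # collect: out[:, i + dh, j + dw, i, j] = in[:, hidx * mask_W_ + widx, i, j]
                out[:, ci + dh, cj + dw, ci, cj] = src
            else:              # distribute: out[:, i, j, i + dh, j + dw] = in[:, hidx * mask_W_ + widx, i, j]
                out[:, ci, cj, ci + dh, cj + dw] = src
    return out.view(n, h * w, h, w)
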
class PSA(nn.Module):
    def __init__(self, in_channels=2048, mid_channels=512, psa_type=2, compact=False, shrink_factor=2, mask_h=59,
                 mask_w=59, normalization_factor=1.0, psa_softmax=True):
        super(PSA, self).__init__()
        assert psa_type in [0, 1, 2]
        self.psa_type = psa_type
        self.compact = compact
        self.shrink_factor = shrink_factor
        self.mask_h = mask_h
        self.mask_w = mask_w
        self.psa_softmax = psa_softmax
        if normalization_factor is None:
            normalization_factor = mask_h * mask_w
        self.normalization_factor = normalization_factor

        self.reduce = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True)
        )
        self.attention = nn.Sequential(
            nn.Conv2d(mid_channels, mid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mask_h * mask_w, kernel_size=1, bias=False),
        )
        if psa_type == 2:
            self.reduce_p = nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(inplace=True)
            )
            self.attention_p = nn.Sequential(
                nn.Conv2d(mid_channels, mid_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_channels, mask_h * mask_w, kernel_size=1, bias=False),
            )
        self.proj = nn.Sequential(
            nn.Conv2d(mid_channels * (2 if psa_type == 2 else 1), in_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        out = x
        if self.psa_type in [0, 1]:
            x = self.reduce(x)
            n, c, h, w = x.size()
            if self.shrink_factor != 1:
                h = (h - 1) // self.shrink_factor + 1  # ceiling division: rounds h / shrink_factor up, e.g. 59 -> 30 for shrink_factor=2
                w = (w - 1) // self.shrink_factor + 1
                x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
            y = self.attention(x)
            if self.compact:
                if self.psa_type == 1:
                    y = y.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y = PF.psa_mask(y, self.psa_type, self.mask_h, self.mask_w)
            if self.psa_softmax:
                y = F.softmax(y, dim=1)
            x = torch.bmm(x.view(n, c, h * w), y.view(n, h * w, h * w)).view(n, c, h, w) * (1.0 / self.normalization_factor)
        elif self.psa_type == 2:
            x_col = self.reduce(x)
            x_dis = self.reduce_p(x)
            n, c, h, w = x_col.size()
            if self.shrink_factor != 1:
                h = (h - 1) // self.shrink_factor + 1
                w = (w - 1) // self.shrink_factor + 1
                x_col = F.interpolate(x_col, size=(h, w), mode='bilinear', align_corners=True)
                x_dis = F.interpolate(x_dis, size=(h, w), mode='bilinear', align_corners=True)
            y_col = self.attention(x_col)
            y_dis = self.attention_p(x_dis)
            if self.compact:
                y_dis = y_dis.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y_col = PF.psa_mask(y_col, 0, self.mask_h, self.mask_w)
                y_dis = PF.psa_mask(y_dis, 1, self.mask_h, self.mask_w)
            if self.psa_softmax:
                y_col = F.softmax(y_col, dim=1)
                y_dis = F.softmax(y_dis, dim=1)
            x_col = torch.bmm(x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(n, c, h, w) * (1.0 / self.normalization_factor)
            x_dis = torch.bmm(x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(n, c, h, w) * (1.0 / self.normalization_factor)
            x = torch.cat([x_col, x_dis], 1)
        x = self.proj(x)
        if self.shrink_factor != 1:
            h = (h - 1) * self.shrink_factor + 1
            w = (w - 1) * self.shrink_factor + 1
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
        return torch.cat((out, x), 1)

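# Hedged sanity check (our addition): runs PSA on a dummy feature map using the compact path,
# so the compiled psa_mask op is not exercised. The 59x59 spatial size mirrors what a 465x465
# crop produces at output stride 8; with shrink_factor=2 the attention operates on a 30x30 grid,
# and the module returns its input concatenated with the aggregated features (2048 + 2048 channels).
def _psa_shape_check():
    psa = PSA(in_channels=2048, mid_channels=512, psa_type=2, compact=True,
              shrink_factor=2, mask_h=30, mask_w=30)
    psa.eval()
    feat = torch.rand(1, 2048, 59, 59)
    with torch.no_grad():
        out = psa(feat)
    assert out.size() == (1, 4096, 59, 59)
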
class PSANet(nn.Module):
    def __init__(self, layers=50, dropout=0.1, classes=2, zoom_factor=8, use_psa=True, psa_type=2, compact=False,
                 shrink_factor=2, mask_h=59, mask_w=59, normalization_factor=1.0, psa_softmax=True,
                 criterion=nn.CrossEntropyLoss(ignore_index=255), pretrained=True):
        super(PSANet, self).__init__()
        assert layers in [50, 101, 152]
        assert classes > 1
        assert zoom_factor in [1, 2, 4, 8]
        assert psa_type in [0, 1, 2]
        self.zoom_factor = zoom_factor
        self.use_psa = use_psa
        self.criterion = criterion

        if layers == 50:
            resnet = models.resnet50(pretrained=pretrained, deep_base=True)
        elif layers == 101:
            resnet = models.resnet101(pretrained=pretrained)
        else:
            resnet = models.resnet152(pretrained=pretrained)
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu, resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool)
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
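        # The two loops above turn the stride-2 stages layer3/layer4 into dilated stages
        # (dilation/padding 2 and 4, stride 1), so the backbone keeps an output stride of 8;
        # e.g. a 465x465 input yields a (465 - 1) / 8 + 1 = 59x59 feature map at layer4.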

        fea_dim = 2048
        if use_psa:
            self.psa = PSA(fea_dim, 512, psa_type, compact, shrink_factor, mask_h, mask_w, normalization_factor, psa_softmax)
            fea_dim *= 2
        self.cls = nn.Sequential(
            nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(512, classes, kernel_size=1)
        )
        if self.training:
            self.aux = nn.Sequential(
                nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout),
                nn.Conv2d(256, classes, kernel_size=1)
            )

    def forward(self, x, y=None):
        x_size = x.size()
        assert (x_size[2] - 1) % 8 == 0 and (x_size[3] - 1) % 8 == 0
        h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1)
        w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1)

        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)
        if self.use_psa:
            x = self.psa(x)
        x = self.cls(x)
        if self.zoom_factor != 1:
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)

        if self.training:
            aux = self.aux(x_tmp)
            if self.zoom_factor != 1:
                aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)
            main_loss = self.criterion(x, y)
            aux_loss = self.criterion(aux, y)
            return x.max(1)[1], main_loss, aux_loss
        else:
            return x


if __name__ == '__main__':
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    crop_h = crop_w = 465
    input = torch.rand(4, 3, crop_h, crop_w).cuda()
    compact = False
    mask_h, mask_w = None, None
    shrink_factor = 2
    if compact:
        mask_h = (crop_h - 1) // (8 * shrink_factor) + 1
        mask_w = (crop_w - 1) // (8 * shrink_factor) + 1
    else:
        assert (mask_h is None and mask_w is None) or (mask_h is not None and mask_w is not None)
        if mask_h is None and mask_w is None:
            mask_h = 2 * ((crop_h - 1) // (8 * shrink_factor) + 1) - 1
            mask_w = 2 * ((crop_w - 1) // (8 * shrink_factor) + 1) - 1
        else:
            assert (mask_h % 2 == 1) and (mask_h >= 3) and (mask_h <= 2 * ((crop_h - 1) // (8 * shrink_factor) + 1) - 1)
            assert (mask_w % 2 == 1) and (mask_w >= 3) and (mask_w <= 2 * ((crop_w - 1) // (8 * shrink_factor) + 1) - 1)

    model = PSANet(layers=50, dropout=0.1, classes=21, zoom_factor=8, use_psa=True, psa_type=2, compact=compact,
                   shrink_factor=shrink_factor, mask_h=mask_h, mask_w=mask_w, psa_softmax=True, pretrained=False).cuda()
    print(model)
    model.eval()
    output = model(input)
    print('PSANet', output.size())
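    # Hedged training-mode example (our addition): with labels and the module in train(),
    # forward returns (per-pixel argmax prediction, main_loss, aux_loss). The random target
    # below is illustrative only; 255 would be the ignore index of the criterion.
    model.train()
    target = torch.randint(0, 21, (4, crop_h, crop_w), dtype=torch.long).cuda()
    pred, main_loss, aux_loss = model(input, target)
    print('PSANet train', pred.size(), main_loss.item(), aux_loss.item())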