import torch from torch import nn import torch.nn.functional as F import core.lib.psa.functional as PF import modeling.backbone.resnet_real as models #运行失败,compact可以运行,但over-completed运行不了。也是跟psamask的实现有关:用到了自定义的torch.autograd.Function(里面用到了cpp文件,导入不了_C模块出错) # # from . import functions # # # def psa_mask(input, psa_type=0, mask_H_=None, mask_W_=None): # return functions.psa_mask(input, psa_type, mask_H_, mask_W_) # # # import torch # from torch.autograd import Function # from .. import src # class PSAMask(Function): # @staticmethod # def forward(ctx, input, psa_type=0, mask_H_=None, mask_W_=None): # assert psa_type in [0, 1] # 0-col, 1-dis # assert (mask_H_ is None and mask_W_ is None) or (mask_H_ is not None and mask_W_ is not None) # num_, channels_, feature_H_, feature_W_ = input.size() # if mask_H_ is None and mask_W_ is None: # mask_H_, mask_W_ = 2 * feature_H_ - 1, 2 * feature_W_ - 1 # assert (mask_H_ % 2 == 1) and (mask_W_ % 2 == 1) # assert channels_ == mask_H_ * mask_W_ # half_mask_H_, half_mask_W_ = (mask_H_ - 1) // 2, (mask_W_ - 1) // 2 # output = torch.zeros([num_, feature_H_ * feature_W_, feature_H_, feature_W_], dtype=input.dtype, device=input.device) # if not input.is_cuda: # src.cpu.psamask_forward(psa_type, input, output, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_) # else: # output = output.cuda() # src.gpu.psamask_forward(psa_type, input, output, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_) # ctx.psa_type, ctx.num_, ctx.channels_, ctx.feature_H_, ctx.feature_W_ = psa_type, num_, channels_, feature_H_, feature_W_ # ctx.mask_H_, ctx.mask_W_, ctx.half_mask_H_, ctx.half_mask_W_ = mask_H_, mask_W_, half_mask_H_, half_mask_W_ # return output # # @staticmethod # def backward(ctx, grad_output): # psa_type, num_, channels_, feature_H_, feature_W_ = ctx.psa_type, ctx.num_, ctx.channels_, ctx.feature_H_, ctx.feature_W_ # mask_H_, mask_W_, half_mask_H_, half_mask_W_ = ctx.mask_H_, ctx.mask_W_, 
#         ctx.half_mask_H_, ctx.half_mask_W_
#     grad_input = torch.zeros([num_, channels_, feature_H_, feature_W_], dtype=grad_output.dtype, device=grad_output.device)
#     if not grad_output.is_cuda:
#         src.cpu.psamask_backward(psa_type, grad_output, grad_input, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
#     else:
#         src.gpu.psamask_backward(psa_type, grad_output, grad_input, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
#     return grad_input, None, None, None
#
#
# psa_mask = PSAMask.apply


class PSA(nn.Module):
    """Point-wise spatial attention (PSA) head.

    The input is reduced to ``mid_channels`` with a 1x1 conv, optionally
    down-sampled by ``shrink_factor``, and every position predicts
    ``mask_h * mask_w`` attention values that are used to aggregate the reduced
    features over all positions. The aggregated features are projected back to
    ``in_channels``, resized to the input resolution and concatenated with the
    input, so the output has ``2 * in_channels`` channels.

    Args:
        in_channels: channels of the input feature map.
        mid_channels: channels after the 1x1 reduction.
        psa_type: 0 = collect, 1 = distribute, 2 = both branches concatenated.
        compact: if True, the ``mask_h * mask_w`` attention channels are used
            directly as the (h*w, h*w) aggregation map, so ``mask_h * mask_w``
            must equal the shrunk ``h * w``. Otherwise the prediction is first
            converted by ``PF.psa_mask`` (over-complete mode).
        shrink_factor: spatial down-sampling applied before attention; each side
            becomes ``ceil(size / shrink_factor)``.
        mask_h, mask_w: size of the attention mask predicted per position.
        normalization_factor: aggregated features are divided by this value;
            ``None`` means ``mask_h * mask_w``.
        psa_softmax: apply a softmax over the mask dimension before aggregation.
    """

    def __init__(self, in_channels=2048, mid_channels=512, psa_type=2, compact=False, shrink_factor=2,
                 mask_h=59, mask_w=59, normalization_factor=1.0, psa_softmax=True):
        super(PSA, self).__init__()
        assert psa_type in [0, 1, 2]
        self.psa_type = psa_type
        self.compact = compact
        self.shrink_factor = shrink_factor
        self.mask_h = mask_h
        self.mask_w = mask_w
        self.psa_softmax = psa_softmax
        if normalization_factor is None:
            normalization_factor = mask_h * mask_w
        self.normalization_factor = normalization_factor

        # Submodules are created in the same order and with the same names/indices
        # as before (reduce, attention, reduce_p, attention_p, proj), so state_dict
        # keys and seeded initialisation are unchanged.
        self.reduce = self._reduce_block(in_channels, mid_channels)
        self.attention = self._attention_block(mid_channels, mask_h * mask_w)
        if psa_type == 2:
            # Separate weights for the distribute branch of the bi-directional variant.
            self.reduce_p = self._reduce_block(in_channels, mid_channels)
            self.attention_p = self._attention_block(mid_channels, mask_h * mask_w)
        self.proj = nn.Sequential(
            nn.Conv2d(mid_channels * (2 if psa_type == 2 else 1), in_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _reduce_block(in_channels, out_channels):
        """Return a 1x1 conv + BN + ReLU channel reduction."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _attention_block(channels, mask_size):
        """Return the head that predicts ``mask_size`` attention values per position."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, mask_size, kernel_size=1, bias=False),
        )

    def _reduce_and_shrink(self, x, reduce):
        """Apply ``reduce`` to ``x`` and down-sample the result by ``shrink_factor``."""
        x = reduce(x)
        if self.shrink_factor != 1:
            h, w = x.size()[2:]
            # (size - 1) // s + 1 == ceil(size / s): the shrunk size is rounded up.
            h = (h - 1) // self.shrink_factor + 1
            w = (w - 1) // self.shrink_factor + 1
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
        return x

    def _aggregate(self, x, attention, mask_type):
        """Run one attention branch on ``x`` (n, c, h, w) and aggregate ``x`` with it.

        ``mask_type`` is 0 (collect) or 1 (distribute). Returns an (n, c, h, w) tensor.
        """
        n, c, h, w = x.size()
        y = attention(x)
        if self.compact:
            if mask_type == 1:
                # Distribute: swap the roles of the two position axes of the h*w x h*w map.
                y = y.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w)
        else:
            y = PF.psa_mask(y, mask_type, self.mask_h, self.mask_w)
        if self.psa_softmax:
            y = F.softmax(y, dim=1)
        return torch.bmm(x.view(n, c, h * w), y.view(n, h * w, h * w)).view(n, c, h, w) * (
            1.0 / self.normalization_factor)

    def forward(self, x):
        """Return ``cat([x, psa(x)], dim=1)``; spatial size is preserved."""
        identity = x
        if self.psa_type == 2:
            x_col = self._aggregate(self._reduce_and_shrink(x, self.reduce), self.attention, 0)
            x_dis = self._aggregate(self._reduce_and_shrink(x, self.reduce_p), self.attention_p, 1)
            x = torch.cat([x_col, x_dis], 1)
        else:
            x = self._aggregate(self._reduce_and_shrink(x, self.reduce), self.attention, self.psa_type)
        x = self.proj(x)
        if self.shrink_factor != 1:
            # Resize to the input's real size rather than to (h - 1) * s + 1. The latter
            # only recovers the input size when (H - 1) % s == 0, and otherwise makes the
            # concatenation below fail.
            x = F.interpolate(x, size=tuple(identity.shape[-2:]), mode='bilinear', align_corners=True)
        return torch.cat((identity, x), 1)


class PSANet(nn.Module):
    """Dilated-ResNet segmentation network with an optional PSA head.

    Input height/width must satisfy ``(size - 1) % 8 == 0``. In training mode
    ``forward(x, y)`` returns ``(predicted labels, main_loss, aux_loss)``. In eval
    mode it returns the class logits upsampled by ``zoom_factor``.
    """

    def __init__(self, layers=50, dropout=0.1, classes=2, zoom_factor=8, use_psa=True, psa_type=2, compact=False,
                 shrink_factor=2, mask_h=59, mask_w=59, normalization_factor=1.0, psa_softmax=True,
                 criterion=None, pretrained=True):
        super(PSANet, self).__init__()
        assert layers in [50, 101, 152]
        assert classes > 1
        assert zoom_factor in [1, 2, 4, 8]
        assert psa_type in [0, 1, 2]
        self.zoom_factor = zoom_factor
        self.use_psa = use_psa
        # Build the default loss per instance instead of sharing one module object
        # through the default argument. The default is unchanged: CE with ignore_index=255.
        self.criterion = criterion if criterion is not None else nn.CrossEntropyLoss(ignore_index=255)

        if layers == 50:
            resnet = models.resnet50(pretrained=pretrained, deep_base=True)
        elif layers == 101:
            # NOTE(review): layer0 below needs conv2/conv3 (deep-base stem), but
            # deep_base is only passed for resnet50; confirm resnet_real's default.
            resnet = models.resnet101(pretrained=pretrained)
        else:
            resnet = models.resnet152(pretrained=pretrained)
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu,
                                    resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool)
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
        self._dilate(self.layer3, 2)
        self._dilate(self.layer4, 4)

        fea_dim = 2048
        if use_psa:
            self.psa = PSA(fea_dim, 512, psa_type, compact, shrink_factor, mask_h, mask_w, normalization_factor,
                           psa_softmax)
            fea_dim *= 2  # PSA concatenates its output with its input
        self.cls = nn.Sequential(
            nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(512, classes, kernel_size=1)
        )
        # A freshly constructed nn.Module is in training mode, so the auxiliary
        # head on layer3 features is always built here.
        if self.training:
            self.aux = nn.Sequential(
                nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout),
                nn.Conv2d(256, classes, kernel_size=1)
            )

    @staticmethod
    def _dilate(layer, dilation):
        """Set stride 1 on a ResNet stage and dilate its ``conv2`` convolutions.

        Modules whose name contains ``conv2`` get the given dilation/padding and
        stride 1; ``downsample.0`` convs get stride 1.
        """
        for name, module in layer.named_modules():
            if 'conv2' in name:
                module.dilation, module.padding, module.stride = (dilation, dilation), (dilation, dilation), (1, 1)
            elif 'downsample.0' in name:
                module.stride = (1, 1)

    def forward(self, x, y=None):
        """Segment ``x``; ``y`` (target labels) is required in training mode."""
        x_size = x.size()
        assert (x_size[2] - 1) % 8 == 0 and (x_size[3] - 1) % 8 == 0
        h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1)
        w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1)

        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)
        if self.use_psa:
            x = self.psa(x)
        x = self.cls(x)
        if self.zoom_factor != 1:
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)

        if self.training:
            aux = self.aux(x_tmp)
            if self.zoom_factor != 1:
                aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)
            main_loss = self.criterion(x, y)
            aux_loss = self.criterion(aux, y)
            return x.max(1)[1], main_loss, aux_loss
        else:
            return x


if __name__ == '__main__':
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    crop_h = crop_w = 465
    inp = torch.rand(4, 3, crop_h, crop_w).cuda()
    # NOTE: compact=False goes through PF.psa_mask, which depends on a compiled
    # extension; compact=True avoids it.
    compact = False
    mask_h, mask_w = None, None
    shrink_factor = 2
    # Side lengths of the attention grid after the /8 backbone and shrink_factor.
    grid_h = (crop_h - 1) // (8 * shrink_factor) + 1
    grid_w = (crop_w - 1) // (8 * shrink_factor) + 1
    if compact:
        mask_h = grid_h
        mask_w = grid_w
    else:
        assert (mask_h is None and mask_w is None) or (mask_h is not None and mask_w is not None)
        if mask_h is None and mask_w is None:
            mask_h = 2 * grid_h - 1
            mask_w = 2 * grid_w - 1
        else:
            assert (mask_h % 2 == 1) and (mask_h >= 3) and (mask_h <= 2 * grid_h - 1)
            # Bounded by the width (the previous code mistakenly used crop_h here).
            assert (mask_w % 2 == 1) and (mask_w >= 3) and (mask_w <= 2 * grid_w - 1)

    model = PSANet(layers=50, dropout=0.1, classes=21, zoom_factor=8, use_psa=True, psa_type=2, compact=compact,
                   shrink_factor=shrink_factor, mask_h=mask_h, mask_w=mask_w, psa_softmax=True,
                   pretrained=False).cuda()
    print(model)
    model.eval()
    output = model(inp)
    print('PSANet', output.size())