256 lines
12 KiB
Python
256 lines
12 KiB
Python
import torch
|
||
from torch import nn
|
||
import torch.nn.functional as F
|
||
import core.lib.psa.functional as PF
|
||
import modeling.backbone.resnet_real as models
|
||
|
||
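# NOTE: this module currently fails to run in over-completed (non-compact) mode; compact mode runs fine.
# The failure is tied to the psa_mask implementation: it relies on a custom torch.autograd.Function
# backed by C++ sources, and importing its compiled `_C` module raises an error.
# The reference implementation is kept below (commented out) for context.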
|
||
#
|
||
# from . import functions
|
||
#
|
||
#
|
||
# def psa_mask(input, psa_type=0, mask_H_=None, mask_W_=None):
|
||
# return functions.psa_mask(input, psa_type, mask_H_, mask_W_)
|
||
#
|
||
#
|
||
# import torch
|
||
# from torch.autograd import Function
|
||
# from .. import src
|
||
|
||
|
||
# class PSAMask(Function):
|
||
# @staticmethod
|
||
# def forward(ctx, input, psa_type=0, mask_H_=None, mask_W_=None):
|
||
# assert psa_type in [0, 1] # 0-col, 1-dis
|
||
# assert (mask_H_ is None and mask_W_ is None) or (mask_H_ is not None and mask_W_ is not None)
|
||
# num_, channels_, feature_H_, feature_W_ = input.size()
|
||
# if mask_H_ is None and mask_W_ is None:
|
||
# mask_H_, mask_W_ = 2 * feature_H_ - 1, 2 * feature_W_ - 1
|
||
# assert (mask_H_ % 2 == 1) and (mask_W_ % 2 == 1)
|
||
# assert channels_ == mask_H_ * mask_W_
|
||
# half_mask_H_, half_mask_W_ = (mask_H_ - 1) // 2, (mask_W_ - 1) // 2
|
||
# output = torch.zeros([num_, feature_H_ * feature_W_, feature_H_, feature_W_], dtype=input.dtype, device=input.device)
|
||
# if not input.is_cuda:
|
||
# src.cpu.psamask_forward(psa_type, input, output, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
|
||
# else:
|
||
# output = output.cuda()
|
||
# src.gpu.psamask_forward(psa_type, input, output, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
|
||
# ctx.psa_type, ctx.num_, ctx.channels_, ctx.feature_H_, ctx.feature_W_ = psa_type, num_, channels_, feature_H_, feature_W_
|
||
# ctx.mask_H_, ctx.mask_W_, ctx.half_mask_H_, ctx.half_mask_W_ = mask_H_, mask_W_, half_mask_H_, half_mask_W_
|
||
# return output
|
||
#
|
||
# @staticmethod
|
||
# def backward(ctx, grad_output):
|
||
# psa_type, num_, channels_, feature_H_, feature_W_ = ctx.psa_type, ctx.num_, ctx.channels_, ctx.feature_H_, ctx.feature_W_
|
||
# mask_H_, mask_W_, half_mask_H_, half_mask_W_ = ctx.mask_H_, ctx.mask_W_, ctx.half_mask_H_, ctx.half_mask_W_
|
||
# grad_input = torch.zeros([num_, channels_, feature_H_, feature_W_], dtype=grad_output.dtype, device=grad_output.device)
|
||
# if not grad_output.is_cuda:
|
||
# src.cpu.psamask_backward(psa_type, grad_output, grad_input, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
|
||
# else:
|
||
# src.gpu.psamask_backward(psa_type, grad_output, grad_input, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
|
||
# return grad_input, None, None, None
|
||
|
||
|
||
# psa_mask = PSAMask.apply
|
||
|
||
|
||
class PSA(nn.Module):
    """Point-wise spatial attention block.

    The input feature map is reduced to ``mid_channels`` with a 1x1 conv,
    optionally shrunk spatially, re-weighted through a per-position attention
    map (``torch.bmm`` over the flattened positions), projected back to
    ``in_channels`` and finally concatenated with the untouched input along
    the channel axis. The output therefore has ``2 * in_channels`` channels
    and the same spatial size as the input.

    Args:
        in_channels: channels of the input feature map.
        mid_channels: channels used inside the attention branches.
        psa_type: 0 or 1 selects a single branch (named ``col`` / ``dis`` in
            this file; presumably "collect" / "distribute"), 2 runs both
            branches and concatenates them before the projection.
        compact: if True the attention conv output is used directly as the
            ``(h*w, h*w)`` attention matrix, so ``mask_h * mask_w`` must equal
            the number of positions of the (shrunk) feature map. If False the
            output is first expanded by ``PF.psa_mask``.
        shrink_factor: spatial down-scaling applied before attention
            (1 disables it).
        mask_h, mask_w: attention mask size; the attention conv emits
            ``mask_h * mask_w`` channels.
        normalization_factor: the aggregated features are multiplied by
            ``1 / normalization_factor``; ``None`` means ``mask_h * mask_w``.
        psa_softmax: apply a softmax over dim 1 of the attention map.
    """

    def __init__(self, in_channels=2048, mid_channels=512, psa_type=2, compact=False, shrink_factor=2, mask_h=59,
                 mask_w=59, normalization_factor=1.0, psa_softmax=True):
        super(PSA, self).__init__()
        assert psa_type in [0, 1, 2]
        self.psa_type = psa_type
        self.compact = compact
        self.shrink_factor = shrink_factor
        self.mask_h = mask_h
        self.mask_w = mask_w
        self.psa_softmax = psa_softmax
        if normalization_factor is None:
            normalization_factor = mask_h * mask_w
        self.normalization_factor = normalization_factor

        # NOTE: sub-modules are created in a fixed order (reduce, attention,
        # reduce_p, attention_p, proj) so that seeded initialisation and
        # state_dict keys stay stable.
        self.reduce = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True)
        )
        self.attention = nn.Sequential(
            nn.Conv2d(mid_channels, mid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mask_h * mask_w, kernel_size=1, bias=False),
        )
        if psa_type == 2:
            # Second, independent branch used only by the bi-directional variant.
            self.reduce_p = nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(inplace=True)
            )
            self.attention_p = nn.Sequential(
                nn.Conv2d(mid_channels, mid_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_channels, mask_h * mask_w, kernel_size=1, bias=False),
            )
        self.proj = nn.Sequential(
            nn.Conv2d(mid_channels * (2 if psa_type == 2 else 1), in_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )

    def _shrink(self, feat):
        """Down-scale ``feat`` by ``shrink_factor`` (no-op when the factor is 1)."""
        if self.shrink_factor == 1:
            return feat
        h, w = feat.shape[2], feat.shape[3]
        # (v - 1) // f + 1 is ceiling division of v by f for positive v.
        size = ((h - 1) // self.shrink_factor + 1, (w - 1) // self.shrink_factor + 1)
        return F.interpolate(feat, size=size, mode='bilinear', align_corners=True)

    @staticmethod
    def _swap_positions(attn, feat):
        """Transpose the two position axes of a compact ``(n, h*w, h, w)`` attention map."""
        n, _, h, w = feat.size()
        return attn.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w)

    def _aggregate(self, feat, attn):
        """Re-weight ``feat`` with ``attn`` via a batched matmul over flattened positions."""
        n, c, h, w = feat.size()
        if self.psa_softmax:
            attn = F.softmax(attn, dim=1)
        mixed = torch.bmm(feat.view(n, c, h * w), attn.view(n, h * w, h * w)).view(n, c, h, w)
        return mixed * (1.0 / self.normalization_factor)

    def forward(self, x):
        """Return ``cat([x, attended(x)], dim=1)`` with the spatial size of ``x``."""
        identity = x
        # Remember the true input resolution: (h - 1) * shrink_factor + 1 only
        # reproduces it when (H - 1) is a multiple of shrink_factor, and any
        # other size would make the final concatenation fail.
        in_h, in_w = x.shape[2], x.shape[3]

        if self.psa_type in [0, 1]:
            x = self._shrink(self.reduce(x))
            y = self.attention(x)
            if not self.compact:
                y = PF.psa_mask(y, self.psa_type, self.mask_h, self.mask_w)
            elif self.psa_type == 1:
                y = self._swap_positions(y, x)
            x = self._aggregate(x, y)
        elif self.psa_type == 2:
            x_col = self.reduce(x)
            x_dis = self.reduce_p(x)
            x_col = self._shrink(x_col)
            x_dis = self._shrink(x_dis)
            y_col = self.attention(x_col)
            y_dis = self.attention_p(x_dis)
            if self.compact:
                # Only the type-1 ("dis") branch needs its position axes swapped.
                y_dis = self._swap_positions(y_dis, x_dis)
            else:
                y_col = PF.psa_mask(y_col, 0, self.mask_h, self.mask_w)
                y_dis = PF.psa_mask(y_dis, 1, self.mask_h, self.mask_w)
            x_col = self._aggregate(x_col, y_col)
            x_dis = self._aggregate(x_dis, y_dis)
            x = torch.cat([x_col, x_dis], 1)

        x = self.proj(x)
        if self.shrink_factor != 1:
            x = F.interpolate(x, size=(in_h, in_w), mode='bilinear', align_corners=True)
        return torch.cat((identity, x), 1)
|
||
|
||
|
||
class PSANet(nn.Module):
    """Segmentation network: dilated ResNet backbone, optional PSA block, classifier head.

    The backbone comes from ``models.resnet{50,101,152}``; its third and
    fourth stages are switched to dilated, stride-1 convolutions. When
    ``use_psa`` is set, a ``PSA`` block doubles the 2048 backbone channels
    before the classifier. An auxiliary head on the stage-3 features is used
    for the second loss term in training mode.

    ``forward`` returns the logits in eval mode, and
    ``(predicted_labels, main_loss, aux_loss)`` in training mode.
    """

    def __init__(self, layers=50, dropout=0.1, classes=2, zoom_factor=8, use_psa=True, psa_type=2, compact=False,
                 shrink_factor=2, mask_h=59, mask_w=59, normalization_factor=1.0, psa_softmax=True,
                 criterion=nn.CrossEntropyLoss(ignore_index=255), pretrained=True):
        super(PSANet, self).__init__()
        assert layers in [50, 101, 152]
        assert classes > 1
        assert zoom_factor in [1, 2, 4, 8]
        assert psa_type in [0, 1, 2]
        self.zoom_factor = zoom_factor
        self.use_psa = use_psa
        self.criterion = criterion

        # NOTE(review): only the 50-layer variant passes deep_base=True, yet the
        # stem below needs conv2/conv3 on every variant - confirm the default
        # of deep_base in modeling.backbone.resnet_real.
        if layers == 50:
            backbone = models.resnet50(pretrained=pretrained, deep_base=True)
        elif layers == 101:
            backbone = models.resnet101(pretrained=pretrained)
        else:
            backbone = models.resnet152(pretrained=pretrained)

        stem = [
            backbone.conv1, backbone.bn1, backbone.relu,
            backbone.conv2, backbone.bn2, backbone.relu,
            backbone.conv3, backbone.bn3, backbone.relu,
            backbone.maxpool,
        ]
        self.layer0 = nn.Sequential(*stem)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Replace the striding of the last two stages by dilation.
        self._dilate_stage(self.layer3, 2)
        self._dilate_stage(self.layer4, 4)

        head_channels = 2048
        if use_psa:
            self.psa = PSA(head_channels, 512, psa_type, compact, shrink_factor, mask_h, mask_w,
                           normalization_factor, psa_softmax)
            # PSA concatenates its result with its input, doubling the channels.
            head_channels *= 2
        self.cls = nn.Sequential(
            nn.Conv2d(head_channels, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(512, classes, kernel_size=1)
        )
        # A freshly constructed nn.Module is in training mode, so this branch
        # is taken at construction time.
        if self.training:
            self.aux = nn.Sequential(
                nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout),
                nn.Conv2d(256, classes, kernel_size=1)
            )

    @staticmethod
    def _dilate_stage(stage, rate):
        """Set every ``conv2`` in ``stage`` to dilation/padding ``rate`` and stride 1.

        The first module of each ``downsample`` branch is set to stride 1 too.
        """
        for name, module in stage.named_modules():
            if 'conv2' in name:
                module.dilation = (rate, rate)
                module.padding = (rate, rate)
                module.stride = (1, 1)
            elif 'downsample.0' in name:
                module.stride = (1, 1)

    def forward(self, x, y=None):
        """Run the network on ``x``; ``y`` holds the target labels used for the losses in training mode."""
        in_size = x.size()
        assert (in_size[2] - 1) % 8 == 0 and (in_size[3] - 1) % 8 == 0
        out_h = int((in_size[2] - 1) / 8 * self.zoom_factor + 1)
        out_w = int((in_size[3] - 1) / 8 * self.zoom_factor + 1)
        rescale = self.zoom_factor != 1

        feat = self.layer2(self.layer1(self.layer0(x)))
        stage3 = self.layer3(feat)
        feat = self.layer4(stage3)
        if self.use_psa:
            feat = self.psa(feat)
        logits = self.cls(feat)
        if rescale:
            logits = F.interpolate(logits, size=(out_h, out_w), mode='bilinear', align_corners=True)

        if not self.training:
            return logits

        aux_logits = self.aux(stage3)
        if rescale:
            aux_logits = F.interpolate(aux_logits, size=(out_h, out_w), mode='bilinear', align_corners=True)
        main_loss = self.criterion(logits, y)
        aux_loss = self.criterion(aux_logits, y)
        _, prediction = logits.max(1)
        return prediction, main_loss, aux_loss
|
||
|
||
|
||
if __name__ == '__main__':
    # Smoke test: build a PSANet-50 and push one random batch through it in
    # eval mode, printing the model and the output size.
    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    # Fall back to CPU instead of crashing in .cuda() when no GPU is visible.
    # NOTE(review): the non-compact path goes through PF.psa_mask; whether that
    # works on CPU depends on how the extension was built - confirm.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    crop_h = crop_w = 465
    images = torch.rand(4, 3, crop_h, crop_w).to(device)
    compact = False
    mask_h, mask_w = None, None
    shrink_factor = 2

    # Size of the feature map the attention operates on: the sizes used here
    # assume backbone stride 8, followed by the PSA shrink factor.
    feat_h = (crop_h - 1) // (8 * shrink_factor) + 1
    feat_w = (crop_w - 1) // (8 * shrink_factor) + 1

    if compact:
        # Compact masks cover exactly the feature map.
        mask_h, mask_w = feat_h, feat_w
    else:
        assert (mask_h is None and mask_w is None) or (mask_h is not None and mask_w is not None)
        if mask_h is None and mask_w is None:
            # Default over-complete mask: 2 * size - 1 in each direction.
            mask_h = 2 * feat_h - 1
            mask_w = 2 * feat_w - 1
        else:
            assert (mask_h % 2 == 1) and (mask_h >= 3) and (mask_h <= 2 * feat_h - 1)
            # The width bound must be derived from crop_w (it previously used crop_h).
            assert (mask_w % 2 == 1) and (mask_w >= 3) and (mask_w <= 2 * feat_w - 1)

    model = PSANet(layers=50, dropout=0.1, classes=21, zoom_factor=8, use_psa=True, psa_type=2, compact=compact,
                   shrink_factor=shrink_factor, mask_h=mask_h, mask_w=mask_w, psa_softmax=True,
                   pretrained=False).to(device)
    print(model)
    model.eval()
    output = model(images)
    print('PSANet', output.size())
|