# Source file: AIlib2/segutils/core/models/psanet_offical.py
# (256 lines, 12 KiB, Python; snapshot dated 2025-04-26 10:35:59 +08:00)
import torch
from torch import nn
import torch.nn.functional as F
import core.lib.psa.functional as PF
import modeling.backbone.resnet_real as models
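# NOTE: Running this fails: compact mode works, but the over-complete (non-compact) mode does not. The failure comes from the psa_mask implementation, which uses a custom torch.autograd.Function backed by C++ sources; importing the compiled `_C` module fails.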
#
# from . import functions
#
#
# def psa_mask(input, psa_type=0, mask_H_=None, mask_W_=None):
# return functions.psa_mask(input, psa_type, mask_H_, mask_W_)
#
#
# import torch
# from torch.autograd import Function
# from .. import src
# class PSAMask(Function):
# @staticmethod
# def forward(ctx, input, psa_type=0, mask_H_=None, mask_W_=None):
# assert psa_type in [0, 1] # 0-col, 1-dis
# assert (mask_H_ is None and mask_W_ is None) or (mask_H_ is not None and mask_W_ is not None)
# num_, channels_, feature_H_, feature_W_ = input.size()
# if mask_H_ is None and mask_W_ is None:
# mask_H_, mask_W_ = 2 * feature_H_ - 1, 2 * feature_W_ - 1
# assert (mask_H_ % 2 == 1) and (mask_W_ % 2 == 1)
# assert channels_ == mask_H_ * mask_W_
# half_mask_H_, half_mask_W_ = (mask_H_ - 1) // 2, (mask_W_ - 1) // 2
# output = torch.zeros([num_, feature_H_ * feature_W_, feature_H_, feature_W_], dtype=input.dtype, device=input.device)
# if not input.is_cuda:
# src.cpu.psamask_forward(psa_type, input, output, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
# else:
# output = output.cuda()
# src.gpu.psamask_forward(psa_type, input, output, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
# ctx.psa_type, ctx.num_, ctx.channels_, ctx.feature_H_, ctx.feature_W_ = psa_type, num_, channels_, feature_H_, feature_W_
# ctx.mask_H_, ctx.mask_W_, ctx.half_mask_H_, ctx.half_mask_W_ = mask_H_, mask_W_, half_mask_H_, half_mask_W_
# return output
#
# @staticmethod
# def backward(ctx, grad_output):
# psa_type, num_, channels_, feature_H_, feature_W_ = ctx.psa_type, ctx.num_, ctx.channels_, ctx.feature_H_, ctx.feature_W_
# mask_H_, mask_W_, half_mask_H_, half_mask_W_ = ctx.mask_H_, ctx.mask_W_, ctx.half_mask_H_, ctx.half_mask_W_
# grad_input = torch.zeros([num_, channels_, feature_H_, feature_W_], dtype=grad_output.dtype, device=grad_output.device)
# if not grad_output.is_cuda:
# src.cpu.psamask_backward(psa_type, grad_output, grad_input, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
# else:
# src.gpu.psamask_backward(psa_type, grad_output, grad_input, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
# return grad_input, None, None, None
# psa_mask = PSAMask.apply
class PSA(nn.Module):
def __init__(self, in_channels=2048, mid_channels=512, psa_type=2, compact=False, shrink_factor=2, mask_h=59,
mask_w=59, normalization_factor=1.0, psa_softmax=True):
super(PSA, self).__init__()
assert psa_type in [0, 1, 2]
self.psa_type = psa_type
self.compact = compact
self.shrink_factor = shrink_factor
self.mask_h = mask_h
self.mask_w = mask_w
self.psa_softmax = psa_softmax
if normalization_factor is None:
normalization_factor = mask_h * mask_w
self.normalization_factor = normalization_factor
self.reduce = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True)
)
self.attention = nn.Sequential(
nn.Conv2d(mid_channels, mid_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, mask_h*mask_w, kernel_size=1, bias=False),
)
if psa_type == 2:
self.reduce_p = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True)
)
self.attention_p = nn.Sequential(
nn.Conv2d(mid_channels, mid_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, mask_h*mask_w, kernel_size=1, bias=False),
)
self.proj = nn.Sequential(
nn.Conv2d(mid_channels * (2 if psa_type == 2 else 1), in_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(in_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
out = x
if self.psa_type in [0, 1]:
x = self.reduce(x)
n, c, h, w = x.size()
if self.shrink_factor != 1:
h = (h - 1) // self.shrink_factor + 1#可以理解为这样做的目的是向上取整。
w = (w - 1) // self.shrink_factor + 1
x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
y = self.attention(x)
if self.compact:
if self.psa_type == 1:
y = y.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w)
else:
y = PF.psa_mask(y, self.psa_type, self.mask_h, self.mask_w)
if self.psa_softmax:
y = F.softmax(y, dim=1)
x = torch.bmm(x.view(n, c, h * w), y.view(n, h * w, h * w)).view(n, c, h, w) * (1.0 / self.normalization_factor)
elif self.psa_type == 2:
x_col = self.reduce(x)
x_dis = self.reduce_p(x)
n, c, h, w = x_col.size()
if self.shrink_factor != 1:
h = (h - 1) // self.shrink_factor + 1
w = (w - 1) // self.shrink_factor + 1
x_col = F.interpolate(x_col, size=(h, w), mode='bilinear', align_corners=True)
x_dis = F.interpolate(x_dis, size=(h, w), mode='bilinear', align_corners=True)
y_col = self.attention(x_col)
y_dis = self.attention_p(x_dis)
if self.compact:
y_dis = y_dis.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w)
else:
y_col = PF.psa_mask(y_col, 0, self.mask_h, self.mask_w)
y_dis = PF.psa_mask(y_dis, 1, self.mask_h, self.mask_w)
if self.psa_softmax:
y_col = F.softmax(y_col, dim=1)
y_dis = F.softmax(y_dis, dim=1)
x_col = torch.bmm(x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(n, c, h, w) * (1.0 / self.normalization_factor)
x_dis = torch.bmm(x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(n, c, h, w) * (1.0 / self.normalization_factor)
x = torch.cat([x_col, x_dis], 1)
x = self.proj(x)
if self.shrink_factor != 1:
h = (h - 1) * self.shrink_factor + 1
w = (w - 1) * self.shrink_factor + 1
x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
return torch.cat((out, x), 1)
class PSANet(nn.Module):
    """PSANet semantic segmentation model.

    A ResNet-50/101/152 backbone whose ``layer3``/``layer4`` are converted to
    dilated convolutions (stride 1, dilation 2 and 4), followed by an optional
    :class:`PSA` head, a classifier ``cls`` and an auxiliary classifier ``aux``
    applied to the ``layer3`` features in training mode.

    Args:
        layers: Backbone depth, one of 50, 101, 152.
        dropout: ``Dropout2d`` probability used in both classifiers.
        classes: Number of output classes (> 1).
        zoom_factor: Output up-sampling relative to the 1/8-resolution size
            computed in ``forward``; one of 1, 2, 4, 8 (8 = input resolution).
        use_psa: Insert the PSA head (doubles the classifier input channels).
        psa_type, compact, shrink_factor, mask_h, mask_w,
        normalization_factor, psa_softmax: Forwarded to :class:`PSA`.
        criterion: Loss used in training mode. ``None`` (default) creates a
            new ``nn.CrossEntropyLoss(ignore_index=255)`` for this model.
        pretrained: Forwarded to the backbone constructor.
    """

    def __init__(self, layers=50, dropout=0.1, classes=2, zoom_factor=8, use_psa=True, psa_type=2, compact=False,
                 shrink_factor=2, mask_h=59, mask_w=59, normalization_factor=1.0, psa_softmax=True,
                 criterion=None, pretrained=True):
        super().__init__()
        assert layers in [50, 101, 152]
        assert classes > 1
        assert zoom_factor in [1, 2, 4, 8]
        assert psa_type in [0, 1, 2]
        self.zoom_factor = zoom_factor
        self.use_psa = use_psa
        # Build a fresh loss per model: a module-valued default argument is a
        # single object shared by (and registered as a child of) every PSANet.
        self.criterion = criterion if criterion is not None else nn.CrossEntropyLoss(ignore_index=255)
        if layers == 50:
            resnet = models.resnet50(pretrained=pretrained, deep_base=True)
        elif layers == 101:
            resnet = models.resnet101(pretrained=pretrained)
        else:
            resnet = models.resnet152(pretrained=pretrained)
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu,
                                    resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool)
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
        # Replace striding in layer3/layer4 with dilation so the feature map
        # keeps the resolution produced by layer2.
        for name, module in self.layer3.named_modules():
            if 'conv2' in name:
                module.dilation, module.padding, module.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in name:
                module.stride = (1, 1)
        for name, module in self.layer4.named_modules():
            if 'conv2' in name:
                module.dilation, module.padding, module.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in name:
                module.stride = (1, 1)
        fea_dim = 2048
        if use_psa:
            self.psa = PSA(fea_dim, 512, psa_type, compact, shrink_factor, mask_h, mask_w, normalization_factor,
                           psa_softmax)
            # PSA concatenates its output with its input.
            fea_dim *= 2
        self.cls = nn.Sequential(
            nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(512, classes, kernel_size=1),
        )
        # self.training is always True right after nn.Module.__init__, so the
        # auxiliary head was (and is) always built; it is only used in training.
        self.aux = nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(256, classes, kernel_size=1),
        )

    def forward(self, x, y=None):
        """Run the network.

        Args:
            x: Input batch (n, channels, H, W) with ``H - 1`` and ``W - 1``
                divisible by 8.
            y: Target labels for ``criterion``; required in training mode.

        Returns:
            In training mode ``(pred, main_loss, aux_loss)``, where ``pred`` is
            the argmax over the class dimension; otherwise the class scores
            of shape (n, classes, h, w).

        Raises:
            ValueError: In training mode when ``y`` is None.
        """
        x_size = x.size()
        assert (x_size[2] - 1) % 8 == 0 and (x_size[3] - 1) % 8 == 0
        # Exact integer form of (size - 1) / 8 * zoom_factor + 1.
        h = (x_size[2] - 1) // 8 * self.zoom_factor + 1
        w = (x_size[3] - 1) // 8 * self.zoom_factor + 1
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)
        if self.use_psa:
            x = self.psa(x)
        x = self.cls(x)
        if self.zoom_factor != 1:
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
        if not self.training:
            return x
        if y is None:
            raise ValueError("targets y are required when PSANet is in training mode")
        aux = self.aux(x_tmp)
        if self.zoom_factor != 1:
            aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)
        main_loss = self.criterion(x, y)
        aux_loss = self.criterion(aux, y)
        return x.max(1)[1], main_loss, aux_loss
if __name__ == '__main__':
    # Smoke test: build PSANet with masks derived from the crop size and run
    # one forward pass on random data.
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    crop_h = crop_w = 465
    dummy = torch.rand(4, 3, crop_h, crop_w, device=device)
    compact = False
    mask_h, mask_w = None, None
    shrink_factor = 2
    if compact:
        # Compact mode: one attention weight per position of the shrunk map.
        mask_h = (crop_h - 1) // (8 * shrink_factor) + 1
        mask_w = (crop_w - 1) // (8 * shrink_factor) + 1
    else:
        assert (mask_h is None and mask_w is None) or (mask_h is not None and mask_w is not None)
        if mask_h is None and mask_w is None:
            mask_h = 2 * ((crop_h - 1) // (8 * shrink_factor) + 1) - 1
            mask_w = 2 * ((crop_w - 1) // (8 * shrink_factor) + 1) - 1
        else:
            assert (mask_h % 2 == 1) and (mask_h >= 3) and (mask_h <= 2 * ((crop_h - 1) // (8 * shrink_factor) + 1) - 1)
            # Bound mask_w by the crop width (was mistakenly crop_h).
            assert (mask_w % 2 == 1) and (mask_w >= 3) and (mask_w <= 2 * ((crop_w - 1) // (8 * shrink_factor) + 1) - 1)
    model = PSANet(layers=50, dropout=0.1, classes=21, zoom_factor=8, use_psa=True, psa_type=2, compact=compact,
                   shrink_factor=shrink_factor, mask_h=mask_h, mask_w=mask_w, psa_softmax=True,
                   pretrained=False).to(device)
    print(model)
    model.eval()
    with torch.no_grad():
        output = model(dummy)
    print('PSANet', output.size())