AIlib2/segutils/core/nn/basic.py

135 lines
5.0 KiB
Python
Raw Normal View History

2025-04-26 10:35:59 +08:00
"""Basic Module for Semantic Segmentation"""
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['_ConvBNPReLU', '_ConvBN', '_BNPReLU', '_ConvBNReLU', '_DepthwiseConv', 'InvertedResidual']
class _ConvBNReLU(nn.Module):
    """Bias-free 2-D convolution followed by a normalization layer and ReLU.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups: forwarded unchanged to ``nn.Conv2d``.
        relu6: when truthy the activation is ``nn.ReLU6``; otherwise
            ``nn.ReLU``. Both are created with ``inplace=False``.
        norm_layer: callable invoked as ``norm_layer(out_channels)`` to
            build the normalization layer.
        **kwargs: accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        # Sub-modules are created in conv -> bn -> relu order so parameter
        # registration (and state_dict key order) stays the same.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        self.bn = norm_layer(out_channels)
        activation_cls = nn.ReLU6 if relu6 else nn.ReLU
        self.relu = activation_cls(inplace=False)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))
class _ConvBNPReLU(nn.Module):
    """Bias-free 2-D convolution followed by a normalization layer and PReLU.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups: forwarded unchanged to ``nn.Conv2d``.
        norm_layer: callable invoked as ``norm_layer(out_channels)`` to
            build the normalization layer.
        **kwargs: accepted and ignored.

    The PReLU is created with ``out_channels`` parameters.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        self.bn = norm_layer(out_channels)
        self.prelu = nn.PReLU(num_parameters=out_channels)

    def forward(self, x):
        """Return ``prelu(bn(conv(x)))``."""
        return self.prelu(self.bn(self.conv(x)))
class _ConvBN(nn.Module):
    """Bias-free 2-D convolution followed by a normalization layer.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups: forwarded unchanged to ``nn.Conv2d``.
        norm_layer: callable invoked as ``norm_layer(out_channels)`` to
            build the normalization layer.
        **kwargs: accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        self.bn = norm_layer(out_channels)

    def forward(self, x):
        """Return ``bn(conv(x))`` with no activation applied."""
        return self.bn(self.conv(x))
class _BNPReLU(nn.Module):
    """Normalization layer followed by a PReLU.

    Args:
        out_channels: passed to ``norm_layer`` and used as the number of
            PReLU parameters.
        norm_layer: callable invoked as ``norm_layer(out_channels)`` to
            build the normalization layer.
        **kwargs: accepted and ignored.
    """

    def __init__(self, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        self.bn = norm_layer(out_channels)
        self.prelu = nn.PReLU(num_parameters=out_channels)

    def forward(self, x):
        """Return ``prelu(bn(x))``."""
        normalized = self.bn(x)
        return self.prelu(normalized)
# -----------------------------------------------------------------
# For PSPNet
# -----------------------------------------------------------------
class _PSPModule(nn.Module):
    """Pyramid pooling: pool at several grid sizes, project, upsample, concat.

    For every entry in ``sizes`` the input is adaptively average-pooled to
    that output size, passed through a 1x1 ``_ConvBNReLU`` that maps
    ``in_channels`` to ``int(in_channels / 4)`` channels, and bilinearly
    resized back to the input's spatial size. The input and all resized
    branches are concatenated along dimension 1, so the result has
    ``in_channels + len(sizes) * int(in_channels / 4)`` channels.

    Args:
        in_channels: number of channels of the input feature map.
        sizes: iterable of output sizes handed to ``nn.AdaptiveAvgPool2d``,
            one pooling branch per entry.
        **kwargs: forwarded to every ``_ConvBNReLU`` branch
            (e.g. ``norm_layer``).
    """

    def __init__(self, in_channels, sizes=(1, 2, 3, 6), **kwargs):
        super(_PSPModule, self).__init__()
        out_channels = int(in_channels / 4)
        self.avgpools = nn.ModuleList()
        self.convs = nn.ModuleList()
        for size in sizes:
            # Fixed: the original appended to the non-existent attribute
            # ``self.avgpool`` and raised AttributeError on construction.
            self.avgpools.append(nn.AdaptiveAvgPool2d(size))
            self.convs.append(_ConvBNReLU(in_channels, out_channels, 1, **kwargs))

    def forward(self, x):
        """Return ``x`` concatenated with every upsampled pooling branch.

        ``x`` is assumed to be 4-D ``(N, C, H, W)``: the trailing dims from
        index 2 are used as the resize target and the concat is on dim 1.
        """
        size = x.size()[2:]
        feats = [x]
        # Fixed: the original iterated ``enumerate(zip(...))``, which bound
        # ``avgpool`` to an int index and ``conv`` to a (pool, conv) tuple.
        for avgpool, conv in zip(self.avgpools, self.convs):
            feats.append(F.interpolate(conv(avgpool(x)), size, mode='bilinear', align_corners=True))
        return torch.cat(feats, dim=1)
# -----------------------------------------------------------------
# For MobileNet
# -----------------------------------------------------------------
class _DepthwiseConv(nn.Module):
    """conv_dw in MobileNet.

    A 3x3 ``_ConvBNReLU`` with ``groups=in_channels`` (stride ``stride``,
    padding 1) followed by a 1x1 ``_ConvBNReLU`` mapping ``in_channels`` to
    ``out_channels``. ``norm_layer`` is passed to both stages; ``**kwargs``
    is accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        # Depthwise stage first, then pointwise, matching the Sequential
        # indices 0 and 1.
        depthwise = _ConvBNReLU(in_channels, in_channels, 3, stride, 1,
                                groups=in_channels, norm_layer=norm_layer)
        pointwise = _ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer)
        self.conv = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        """Run the depthwise stage then the pointwise stage on ``x``."""
        out = self.conv(x)
        return out
# -----------------------------------------------------------------
# For MobileNetV2
# -----------------------------------------------------------------
class InvertedResidual(nn.Module):
    """Expand -> depthwise -> linear-projection block with optional skip.

    Layers, in order inside ``self.conv``:
      1. (only when ``expand_ratio != 1``) 1x1 ``_ConvBNReLU`` with ReLU6
         expanding to ``int(round(in_channels * expand_ratio))`` channels;
      2. 3x3 ``_ConvBNReLU`` with ReLU6, ``groups`` equal to that width,
         stride ``stride`` and padding 1;
      3. 1x1 bias-free ``nn.Conv2d`` to ``out_channels``;
      4. ``norm_layer(out_channels)`` with no activation after it.

    The input is added to the output only when ``stride == 1`` and
    ``in_channels == out_channels``.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        stride: must be 1 or 2 (checked with ``assert``).
        expand_ratio: multiplier for the intermediate width.
        norm_layer: normalization layer factory used by every stage.
        **kwargs: accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, stride, expand_ratio, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        assert stride in [1, 2]
        same_shape = (stride == 1) and (in_channels == out_channels)
        self.use_res_connect = same_shape
        hidden = int(round(in_channels * expand_ratio))

        stages = []
        if expand_ratio != 1:
            # pw
            expand = _ConvBNReLU(in_channels, hidden, 1, relu6=True, norm_layer=norm_layer)
            stages.append(expand)
        # dw
        stages.append(_ConvBNReLU(hidden, hidden, 3, stride, 1,
                                  groups=hidden, relu6=True, norm_layer=norm_layer))
        # pw-linear
        stages.append(nn.Conv2d(hidden, out_channels, 1, bias=False))
        stages.append(norm_layer(out_channels))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block, adding ``x`` back when the skip is enabled."""
        if not self.use_res_connect:
            return self.conv(x)
        return x + self.conv(x)
if __name__ == '__main__':
    # Smoke run: push one random 1x32x64x64 tensor through a stride-2
    # InvertedResidual (32 -> 64 channels, expand_ratio 1). Nothing is
    # printed or checked; the run only has to finish without raising.
    sample = torch.randn(1, 32, 64, 64)
    net = InvertedResidual(32, 64, 2, 1)
    result = net(sample)