AIlib2/DMPRUtils/model/detector.py

65 lines
2.7 KiB
Python

"""Defines the detector network structure."""
import torch
from torch import nn
from DMPRUtils.model.network import define_halve_unit, define_detector_block
class YetAnotherDarknet(nn.modules.Module):
    """Backbone feature extractor in the style of darknet.

    In the original author's words: yet another darknet, imitating
    darknet-53 with the depth of darknet-19.

    The network is a stem (3x3 convolution without bias, batch norm,
    leaky ReLU with slope 0.1) followed by five stages.  Every stage is one
    ``define_halve_unit`` followed by one or two ``define_detector_block``
    groups, and the channel argument handed to those helpers doubles from
    one stage to the next.

    Args:
        input_channel_size: Number of channels of the input tensor.
        depth_factor: Output channel count of the stem convolution and base
            value of the per-stage channel argument.
    """

    def __init__(self, input_channel_size, depth_factor):
        super().__init__()

        # Stem: keeps the spatial size (stride 1, padding 1) and maps the
        # input to ``depth_factor`` channels.
        modules = [
            nn.Conv2d(input_channel_size, depth_factor, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(depth_factor),
            nn.LeakyReLU(0.1),
        ]

        # Number of detector blocks that follow the halve unit in stages 1-5.
        blocks_per_stage = (1, 1, 2, 2, 1)

        stage_channels = depth_factor
        for stage_index, block_count in enumerate(blocks_per_stage):
            # The first stage uses ``depth_factor`` as is; each later stage
            # doubles the value used by the previous one.
            if stage_index > 0:
                stage_channels *= 2
            # NOTE(review): judging by its name the halve unit presumably
            # halves the spatial resolution -- confirm in
            # DMPRUtils.model.network.
            modules.extend(define_halve_unit(stage_channels))
            for _ in range(block_count):
                modules.extend(define_detector_block(stage_channels))

        self.model = nn.Sequential(*modules)

    def forward(self, *x):
        """Run the backbone on the first positional argument.

        Only ``x[0]`` is used; any further positional arguments are ignored.
        """
        image = x[0]
        return self.model(image)
class DirectionalPointDetector(nn.modules.Module):
    """Detector for points with direction.

    A ``YetAnotherDarknet`` backbone extracts features, two detector blocks
    and a bias-free 1x1 convolution turn them into ``output_channel_size``
    prediction channels, and ``forward`` squashes those channels:

    * the first 4 channels (confidence, shape, offset_x, offset_y) go
      through a sigmoid;
    * all remaining channels (the angle prediction) go through tanh.

    Args:
        input_channel_size: Number of channels of the input tensor.
        depth_factor: Base channel count handed to the backbone.  The
            prediction head's final convolution expects
            ``32 * depth_factor`` input channels.
        output_channel_size: Number of prediction channels.  Must be
            greater than 4 for ``forward`` to succeed, because at least one
            angle channel has to remain after the 4 point channels.
    """

    # Leading prediction channels that hold the point values: confidence,
    # shape, offset_x and offset_y.
    _POINT_CHANNEL_SIZE = 4

    def __init__(self, input_channel_size, depth_factor, output_channel_size):
        super(DirectionalPointDetector, self).__init__()
        self.extract_feature = YetAnotherDarknet(input_channel_size,
                                                 depth_factor)
        layers = []
        layers += define_detector_block(16 * depth_factor)
        layers += define_detector_block(16 * depth_factor)
        # 1x1 convolution mapping the head features to the raw predictions.
        layers += [nn.Conv2d(32 * depth_factor, output_channel_size,
                             kernel_size=1, stride=1, padding=0, bias=False)]
        self.predict = nn.Sequential(*layers)

    def forward(self, *x):
        """Predict point and angle values for the first positional argument.

        Only ``x[0]`` is used; any further positional arguments are ignored.
        The channel split is done along dimension 1.
        (Assumes a batched input whose dimension 1 is the channel
        dimension -- TODO confirm against callers.)

        Returns:
            A tensor with the same shape as the raw prediction, where the
            first 4 channels were passed through a sigmoid and the
            remaining channels through tanh.

        Raises:
            ValueError: If the prediction has 4 or fewer channels, i.e. no
                angle channel is left.
        """
        prediction = self.predict(self.extract_feature(x[0]))

        point_size = self._POINT_CHANNEL_SIZE
        angle_size = prediction.size(1) - point_size
        if angle_size <= 0:
            raise ValueError(
                "prediction must have more than %d channels, got %d"
                % (point_size, prediction.size(1)))

        # Split with explicit section sizes.  Passing a plain chunk size of
        # 4 would produce ceil(C / 4) chunks and therefore only unpack into
        # two tensors when the channel count C is between 5 and 8.
        point_pred, angle_pred = torch.split(
            prediction, [point_size, angle_size], dim=1)

        # The 4 point values (confidence, shape, offset_x, offset_y) are
        # squashed by the sigmoid; the angle values are squashed by tanh.
        point_pred = torch.sigmoid(point_pred)
        angle_pred = torch.tanh(angle_pred)
        return torch.cat((point_pred, angle_pred), dim=1)