65 lines
2.7 KiB
Python
65 lines
2.7 KiB
Python
"""Defines the detector network structure."""
|
|
import torch
|
|
from torch import nn
|
|
from DMPRUtils.model.network import define_halve_unit, define_detector_block
|
|
|
|
|
|
class YetAnotherDarknet(nn.modules.Module):
    """Yet another darknet, imitating darknet-53 with depth of darknet-19.

    The whole network is one ``nn.Sequential`` kept in ``self.model``:

    * a stem of 3x3 convolution (stride 1, padding 1, no bias), batch norm
      and LeakyReLU(0.1), producing ``depth_factor`` channels;
    * five stages, each consisting of one ``define_halve_unit`` followed by
      one or two ``define_detector_block`` units, with the width argument
      passed to those helpers doubling from stage to stage
      (``depth_factor`` * 1, 2, 4, 8, 16).
    """

    def __init__(self, input_channel_size, depth_factor):
        super(YetAnotherDarknet, self).__init__()

        # Stem convolution, normalisation and activation.
        stem_conv = nn.Conv2d(input_channel_size, depth_factor, kernel_size=3,
                              stride=1, padding=1, bias=False)
        layers = [stem_conv, nn.BatchNorm2d(depth_factor), nn.LeakyReLU(0.1)]

        # (multiple of depth_factor handed to the helpers,
        #  number of detector blocks following the halving unit).
        # Stages are emitted strictly in this order so the Sequential indices
        # (and therefore state_dict keys of saved weights) do not move.
        stage_plan = ((1, 1), (2, 1), (4, 2), (8, 2), (16, 1))
        for width_multiplier, block_count in stage_plan:
            stage_width = depth_factor * width_multiplier
            layers.extend(define_halve_unit(stage_width))
            for _ in range(block_count):
                layers.extend(define_detector_block(stage_width))

        self.model = nn.Sequential(*layers)

    def forward(self, *x):
        """Feed the first positional argument through ``self.model``.

        Any additional positional arguments are accepted and ignored.
        """
        first_input = x[0]
        return self.model(first_input)
|
|
|
|
|
|
class DirectionalPointDetector(nn.modules.Module):
    """Detector for point with direction.

    A ``YetAnotherDarknet`` backbone extracts features, two
    ``define_detector_block`` units refine them, and a final bias-free 1x1
    convolution maps ``32 * depth_factor`` channels to
    ``output_channel_size`` prediction channels.

    The first ``_NUM_POINT_CHANNELS`` output channels describe the point
    (confidence, shape, offset_x, offset_y) and are squashed with a sigmoid;
    all remaining channels describe the direction and are squashed with tanh.
    """

    # Leading output channels holding point attributes: confidence, shape,
    # offset_x, offset_y. Every channel after these is a direction channel.
    _NUM_POINT_CHANNELS = 4

    def __init__(self, input_channel_size, depth_factor, output_channel_size):
        """Build the backbone and prediction head.

        Args:
            input_channel_size: channels of the input image tensor.
            depth_factor: base width; the head expects the backbone to emit
                ``32 * depth_factor`` channels.
            output_channel_size: total prediction channels; must exceed
                ``_NUM_POINT_CHANNELS`` so at least one direction channel
                exists.

        Raises:
            ValueError: if ``output_channel_size`` leaves no room for
                direction channels.
        """
        super(DirectionalPointDetector, self).__init__()
        # Fail fast here instead of with an opaque unpacking/empty-tensor
        # problem the first time forward() runs.
        if output_channel_size <= self._NUM_POINT_CHANNELS:
            raise ValueError(
                'output_channel_size must be greater than %d (point channels '
                'followed by at least one direction channel), got %r'
                % (self._NUM_POINT_CHANNELS, output_channel_size))

        # Backbone is built before the head to keep the original
        # parameter-initialisation order.
        self.extract_feature = YetAnotherDarknet(input_channel_size,
                                                 depth_factor)

        # NOTE(review): define_detector_block(16 * depth_factor) presumably
        # consumes and produces 32 * depth_factor channels, matching the 1x1
        # conv below — confirm against DMPRUtils.model.network.
        layers = []
        layers += define_detector_block(16 * depth_factor)
        layers += define_detector_block(16 * depth_factor)
        layers += [nn.Conv2d(32 * depth_factor, output_channel_size,
                             kernel_size=1, stride=1, padding=0, bias=False)]
        self.predict = nn.Sequential(*layers)

    def forward(self, *x):
        """Run detection on ``x[0]``; extra positional args are ignored.

        Returns:
            A tensor with the head's channel layout: channels
            ``[0, _NUM_POINT_CHANNELS)`` passed through sigmoid (range
            (0, 1)), the remaining channels passed through tanh
            (range (-1, 1)), concatenated back along dim 1.
        """
        prediction = self.predict(self.extract_feature(x[0]))
        split = self._NUM_POINT_CHANNELS
        # Slice rather than torch.split(prediction, 4, dim=1): split yields
        # one chunk per 4 channels and only unpacks into two names for 5-8
        # channels, while slicing works for any direction-channel count.
        # The 4 point values (confidence, shape, offset_x, offset_y) lie in
        # [0, 1].
        point_pred = torch.sigmoid(prediction[:, :split])
        angle_pred = torch.tanh(prediction[:, split:])
        return torch.cat((point_pred, angle_pred), dim=1)
|