|
- """Defines the detector network structure."""
- import torch
- from torch import nn
- from DMPRUtils.model.network import define_halve_unit, define_detector_block
-
-
class YetAnotherDarknet(nn.modules.Module):
    """Yet another darknet, imitating darknet-53 with depth of darknet-19.

    The network is a single ``nn.Sequential`` made of a stem (3x3 conv,
    batch norm, leaky ReLU) followed by five stages. Every stage is one
    ``define_halve_unit`` followed by one or two ``define_detector_block``
    groups, and the channel argument handed to those helpers doubles before
    each stage after the first.

    Args:
        input_channel_size: Channel count of the input tensor.
        depth_factor: Output channel count of the stem convolution and the
            base value passed to the stage helpers.
    """

    def __init__(self, input_channel_size, depth_factor):
        super(YetAnotherDarknet, self).__init__()
        # Stage 0: stem that lifts the input to ``depth_factor`` channels.
        stem = [
            nn.Conv2d(input_channel_size, depth_factor, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(depth_factor),
            nn.LeakyReLU(0.1),
        ]
        # Stages 1..5: number of detector blocks that follow each halve unit.
        # NOTE(review): the helpers live in DMPRUtils.model.network; they
        # presumably return lists of layers -- confirm their channel
        # arithmetic there.
        blocks_per_stage = (1, 1, 2, 2, 1)
        stages = []
        width = depth_factor
        for stage_index, block_count in enumerate(blocks_per_stage):
            if stage_index > 0:
                # Every stage after the first works on twice the width.
                width *= 2
            stages.extend(define_halve_unit(width))
            for _ in range(block_count):
                stages.extend(define_detector_block(width))
        self.model = nn.Sequential(*(stem + stages))

    def forward(self, *x):
        """Run the first positional argument through the network.

        Any additional positional arguments are ignored.
        """
        image = x[0]
        return self.model(image)
-
-
class DirectionalPointDetector(nn.modules.Module):
    """Detector for point with direction.

    A ``YetAnotherDarknet`` feature extractor feeds a prediction head made of
    two detector blocks and a final 1x1 convolution producing
    ``output_channel_size`` channels. In ``forward`` the first four output
    channels go through a sigmoid and every remaining channel goes through a
    tanh.

    Args:
        input_channel_size: Channel count of the input tensor.
        depth_factor: Base width forwarded to ``YetAnotherDarknet``.
        output_channel_size: Channel count of the prediction. Must be greater
            than 4, otherwise ``forward`` raises ``ValueError``.
    """

    # Leading channels squashed to [0, 1]: confidence, shape, offset_x,
    # offset_y.
    _POINT_CHANNELS = 4

    def __init__(self, input_channel_size, depth_factor, output_channel_size):
        super(DirectionalPointDetector, self).__init__()
        self.extract_feature = YetAnotherDarknet(input_channel_size,
                                                 depth_factor)
        layers = []
        # 16 * depth_factor equals the value the backbone passes to its last
        # stage (depth_factor doubled four times).
        # NOTE(review): assumes define_detector_block(c) consumes and emits
        # 2 * c channels, matching the 32 * depth_factor input of the final
        # convolution -- confirm in DMPRUtils.model.network.
        layers += define_detector_block(16 * depth_factor)
        layers += define_detector_block(16 * depth_factor)
        layers += [nn.Conv2d(32 * depth_factor, output_channel_size,
                             kernel_size=1, stride=1, padding=0, bias=False)]
        self.predict = nn.Sequential(*layers)

    def forward(self, *x):
        """Predict from the first positional argument; extra ones are ignored.

        Returns:
            A tensor shaped like the raw prediction, with the first four
            channels passed through sigmoid and the rest through tanh.

        Raises:
            ValueError: If the prediction has four or fewer channels, so
                there is nothing left for the tanh part.
        """
        prediction = self.predict(self.extract_feature(x[0]))
        point_channels = self._POINT_CHANNELS
        channel_count = prediction.shape[1]
        if channel_count <= point_channels:
            raise ValueError(
                "expected more than {} prediction channels, got {}".format(
                    point_channels, channel_count))
        # Slice explicitly instead of torch.split(prediction, 4, dim=1):
        # split cuts the tensor into chunks of size 4, so unpacking into two
        # names only worked for 5..8 channels.
        # The first 4 channels are the 4 values confidence, shape, offset_x
        # and offset_y, whose range is [0, 1].
        point_pred = torch.sigmoid(prediction[:, :point_channels])
        angle_pred = torch.tanh(prediction[:, point_channels:])
        return torch.cat((point_pred, angle_pred), dim=1)
|