"""Defines the detector network structure.""" import torch from torch import nn from DMPRUtils.model.network import define_halve_unit, define_detector_block class YetAnotherDarknet(nn.modules.Module): """Yet another darknet, imitating darknet-53 with depth of darknet-19.""" def __init__(self, input_channel_size, depth_factor): super(YetAnotherDarknet, self).__init__() layers = [] # 0 layers += [nn.Conv2d(input_channel_size, depth_factor, kernel_size=3, stride=1, padding=1, bias=False)] layers += [nn.BatchNorm2d(depth_factor)] layers += [nn.LeakyReLU(0.1)] # 1 layers += define_halve_unit(depth_factor) layers += define_detector_block(depth_factor) # 2 depth_factor *= 2 layers += define_halve_unit(depth_factor) layers += define_detector_block(depth_factor) # 3 depth_factor *= 2 layers += define_halve_unit(depth_factor) layers += define_detector_block(depth_factor) layers += define_detector_block(depth_factor) # 4 depth_factor *= 2 layers += define_halve_unit(depth_factor) layers += define_detector_block(depth_factor) layers += define_detector_block(depth_factor) # 5 depth_factor *= 2 layers += define_halve_unit(depth_factor) layers += define_detector_block(depth_factor) self.model = nn.Sequential(*layers) def forward(self, *x): return self.model(x[0]) class DirectionalPointDetector(nn.modules.Module): """Detector for point with direction.""" def __init__(self, input_channel_size, depth_factor, output_channel_size): super(DirectionalPointDetector, self).__init__() self.extract_feature = YetAnotherDarknet(input_channel_size, depth_factor) layers = [] layers += define_detector_block(16 * depth_factor) layers += define_detector_block(16 * depth_factor) layers += [nn.Conv2d(32 * depth_factor, output_channel_size, kernel_size=1, stride=1, padding=0, bias=False)] self.predict = nn.Sequential(*layers) def forward(self, *x): prediction = self.predict(self.extract_feature(x[0])) # 4 represents that there are 4 value: confidence, shape, offset_x, # offset_y, whose range is between [0, 1]. point_pred, angle_pred = torch.split(prediction, 4, dim=1) point_pred = torch.sigmoid(point_pred) angle_pred = torch.tanh(angle_pred) return torch.cat((point_pred, angle_pred), dim=1)