DMPR-PS/detector.py

64 lines
2.5 KiB
Python

# -*- coding: utf-8 -*-
import torch
from torch import nn
from network import define_halve_unit, define_detector_block
class YetAnotherDarknet(nn.Module):
    """Yet another darknet, imitating darknet-53 with depth of darknet-19.

    The network is a single ``nn.Sequential`` stored in ``self.model``. It
    starts with a 3x3 convolution stem and is followed by five stages. Each
    stage is one ``define_halve_unit`` followed by one or two
    ``define_detector_block`` groups, and the width passed to those helpers
    doubles from one stage to the next.

    Args:
        input_channel_size: Number of channels of the input tensor; used as
            ``in_channels`` of the stem convolution.
        depth_factor: Number of output channels of the stem convolution and
            the width handed to the helpers of the first stage. The last
            stage receives ``16 * depth_factor``.
    """

    # Number of detector blocks that follow the halve unit in stages 1..5.
    _BLOCKS_PER_STAGE = (1, 1, 2, 2, 1)

    def __init__(self, input_channel_size, depth_factor):
        super(YetAnotherDarknet, self).__init__()
        # Stage 0: stem (conv -> batch norm -> leaky ReLU), resolution kept
        # unchanged by stride 1 / padding 1.
        modules = [
            nn.Conv2d(input_channel_size, depth_factor, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(depth_factor),
            nn.LeakyReLU(0.1),
        ]
        # Stages 1..5. The helpers come from ``network``; presumably the
        # halve unit downsamples spatially and widens the channels -- confirm
        # against network.py.
        width = depth_factor
        for stage, block_count in enumerate(self._BLOCKS_PER_STAGE):
            if stage > 0:
                # Every stage after the first works at twice the width.
                width *= 2
            modules.extend(define_halve_unit(width))
            for _ in range(block_count):
                modules.extend(define_detector_block(width))
        self.model = nn.Sequential(*modules)

    def forward(self, *x):
        """Run the backbone on the first positional argument.

        Any additional positional arguments are ignored.
        """
        image = x[0]
        return self.model(image)
class DirectionalPointDetector(nn.modules.Module):
    """Detector for point with direction.

    A ``YetAnotherDarknet`` backbone (``self.extract_feature``) feeds a small
    prediction head (``self.predict``) made of two detector blocks and a
    final 1x1 convolution that produces ``output_channel_size`` channels.

    Args:
        input_channel_size: Number of channels of the input tensor, passed
            through to the backbone.
        depth_factor: Base width of the backbone. The head is built with
            ``16 * depth_factor``, the same value the backbone passes to its
            last stage.
        output_channel_size: Number of channels of the prediction. Because
            ``forward`` splits the prediction into chunks of 3 channels and
            unpacks exactly two chunks, this must be 4, 5 or 6; any other
            value makes ``forward`` raise ``ValueError`` on unpacking.
    """

    def __init__(self, input_channel_size, depth_factor, output_channel_size):
        super(DirectionalPointDetector, self).__init__()
        self.extract_feature = YetAnotherDarknet(input_channel_size,
                                                 depth_factor)
        layers = []
        layers += define_detector_block(16 * depth_factor)
        layers += define_detector_block(16 * depth_factor)
        # The 1x1 projection reads 32 * depth_factor channels, i.e. twice the
        # value given to the detector blocks above.
        # NOTE(review): presumably define_detector_block(n) emits 2 * n
        # channels -- confirm against network.py.
        layers += [nn.Conv2d(32 * depth_factor, output_channel_size,
                             kernel_size=1, stride=1, padding=0, bias=False)]
        self.predict = nn.Sequential(*layers)

    def forward(self, *x):
        """Predict point and angle channels for the first positional argument.

        Additional positional arguments are ignored.

        Returns:
            A tensor with the same shape as the raw head output. Its first
            3 channels (dim 1) have been passed through a sigmoid, so they
            lie in (0, 1); the remaining channels have been passed through
            tanh, so they lie in (-1, 1).
        """
        feature = self.extract_feature(x[0])
        prediction = self.predict(feature)
        # First chunk: 3 "point" channels; second chunk: the remaining
        # "angle" channels. NOTE(review): the meaning of each individual
        # channel is defined by the training targets, not visible here.
        point_pred, angle_pred = torch.split(prediction, 3, dim=1)
        # torch.sigmoid / torch.tanh replace the legacy nn.functional
        # aliases, which emit a deprecation warning on every call in older
        # PyTorch releases; the results are numerically identical.
        point_pred = torch.sigmoid(point_pred)
        angle_pred = torch.tanh(angle_pred)
        return torch.cat((point_pred, angle_pred), dim=1)