您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

55 行
2.1KB

  1. """Universal network struture unit definition."""
  2. from torch import nn
  3. def define_squeeze_unit(basic_channel_size):
  4. """Define a 1x1 squeeze convolution with norm and activation."""
  5. conv = nn.Conv2d(2 * basic_channel_size, basic_channel_size, kernel_size=1,
  6. stride=1, padding=0, bias=False)
  7. norm = nn.BatchNorm2d(basic_channel_size)
  8. relu = nn.LeakyReLU(0.1)
  9. layers = [conv, norm, relu]
  10. return layers
  11. def define_expand_unit(basic_channel_size):
  12. """Define a 3x3 expand convolution with norm and activation."""
  13. conv = nn.Conv2d(basic_channel_size, 2 * basic_channel_size, kernel_size=3,
  14. stride=1, padding=1, bias=False)
  15. norm = nn.BatchNorm2d(2 * basic_channel_size)
  16. relu = nn.LeakyReLU(0.1)
  17. layers = [conv, norm, relu]
  18. return layers
  19. def define_halve_unit(basic_channel_size):
  20. """Define a 4x4 stride 2 expand convolution with norm and activation."""
  21. conv = nn.Conv2d(basic_channel_size, 2 * basic_channel_size, kernel_size=4,
  22. stride=2, padding=1, bias=False)
  23. norm = nn.BatchNorm2d(2 * basic_channel_size)
  24. relu = nn.LeakyReLU(0.1)
  25. layers = [conv, norm, relu]
  26. return layers
  27. def define_depthwise_expand_unit(basic_channel_size):
  28. """Define a 3x3 expand convolution with norm and activation."""
  29. conv1 = nn.Conv2d(basic_channel_size, 2 * basic_channel_size,
  30. kernel_size=1, stride=1, padding=0, bias=False)
  31. norm1 = nn.BatchNorm2d(2 * basic_channel_size)
  32. relu1 = nn.LeakyReLU(0.1)
  33. conv2 = nn.Conv2d(2 * basic_channel_size, 2 * basic_channel_size, kernel_size=3,
  34. stride=1, padding=1, bias=False, groups=2 * basic_channel_size)
  35. norm2 = nn.BatchNorm2d(2 * basic_channel_size)
  36. relu2 = nn.LeakyReLU(0.1)
  37. layers = [conv1, norm1, relu1, conv2, norm2, relu2]
  38. return layers
  39. def define_detector_block(basic_channel_size):
  40. """Define a unit composite of a squeeze and expand unit."""
  41. layers = []
  42. layers += define_squeeze_unit(basic_channel_size)
  43. layers += define_expand_unit(basic_channel_size)
  44. return layers