  1. """Utils for Semantic Segmentation"""
  2. import threading
  3. import torch
  4. import torch.cuda.comm as comm
  5. from torch.nn.parallel.data_parallel import DataParallel
  6. from torch.nn.parallel._functions import Broadcast
  7. from torch.autograd import Function
  8. __all__ = ['DataParallelModel', 'DataParallelCriterion']
class Reduce(Function):
    """Sums tensors that live on different GPUs; the gradient is broadcast back."""

    @staticmethod
    def forward(ctx, *inputs):
        # Remember which GPU each input came from so the gradient can be sent back.
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutputs):
        return Broadcast.apply(ctx.target_gpus, gradOutputs)

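# Hypothetical sketch (not part of the original file) of what ``Reduce`` does, assuming
# at least two CUDA devices are available: the forward pass sums per-GPU tensors onto a
# single device, and the backward pass broadcasts the gradient back to each input device.
def _reduce_example():
    a = torch.ones(2, device='cuda:0', requires_grad=True)
    b = torch.ones(2, device='cuda:1', requires_grad=True)
    s = Reduce.apply(a, b)   # elementwise sum, tensor([2., 2.]), gathered onto a single device
    s.sum().backward()       # a.grad and b.grad are both tensor([1., 1.])
    return s
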
class DataParallelModel(DataParallel):
    """Data parallelism.

    Hides the difference between single and multiple GPUs from the user.
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backward
    pass, gradients from each replica are summed into the original module.
    The batch size should be larger than the number of GPUs used.

    Parameters
    ----------
    module : object
        Network to be parallelized.
    sync : bool
        Enable synchronization (default: False).

    Inputs:
        - **inputs**: list of inputs
    Outputs:
        - **outputs**: list of outputs, one per device (not gathered)

    Example::
        >>> net = DataParallelModel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)  # input_var can be on any device, including CPU
    """

    def gather(self, outputs, output_device):
        # Keep the per-device outputs as a list instead of gathering them onto one device.
        return outputs

    def replicate(self, module, device_ids):
        modules = super(DataParallelModel, self).replicate(module, device_ids)
        return modules

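# Hypothetical sketch (not part of the original file): because ``gather`` above is a
# no-op, calling a ``DataParallelModel`` returns a list with one entry per GPU rather
# than a single gathered tensor. ``model`` and ``images`` are placeholders here.
def _data_parallel_model_example(model, images):
    net = DataParallelModel(model, device_ids=[0, 1, 2]).cuda()
    outputs = net(images)   # list: outputs[i] stays on device_ids[i]
    # For inference, the per-device outputs can still be merged manually if needed.
    from torch.nn.parallel import gather
    return gather(outputs, target_device=0)
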
# Reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py
class DataParallelCriterion(DataParallel):
    """
    Computes the loss on multiple GPUs, which balances memory usage for
    semantic segmentation.

    The targets are split across the specified devices by chunking in
    the batch dimension. Please use together with :class:`DataParallelModel`.

    Example::
        >>> net = DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """

    def forward(self, inputs, *targets, **kwargs):
        # `inputs` should be the (per-device) outputs of DataParallelModel.
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        # Split the targets across the devices along the batch dimension.
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = criterion_parallel_apply(replicas, inputs, targets, kwargs)
        # Sum the per-device losses onto one device and average them.
        return Reduce.apply(*outputs) / len(outputs)

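# Hypothetical usage sketch (not part of the original file): one training step that
# combines DataParallelModel and DataParallelCriterion. ``model``, ``criterion``,
# ``optimizer``, ``images`` and ``masks`` are placeholders; the wrappers would normally
# be built once, outside the step.
def _parallel_training_step(model, criterion, optimizer, images, masks):
    net = DataParallelModel(model, device_ids=[0, 1, 2]).cuda()
    parallel_criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2]).cuda()
    outputs = net(images)                      # list of per-GPU outputs, left in place
    loss = parallel_criterion(outputs, masks)  # per-GPU losses summed by Reduce, then averaged
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
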
def get_a_var(obj):
    """Return the first tensor found in a (possibly nested) list, tuple or dict."""
    if isinstance(obj, torch.Tensor):
        return obj
    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None

def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional), :attr:`targets` (positional) and
    :attr:`kwargs_tup` (keyword) on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        targets (tensor): targets of the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`targets`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # `input` is one per-device output of DataParallelModel and `target`
                # the matching chunk of targets; either may be a bare tensor.
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                if not isinstance(target, (list, tuple)):
                    target = (target,)
                output = module(*(tuple(input) + tuple(target)), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        # One Python thread per replica; the CUDA kernels still run concurrently.
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target, kwargs, device))
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
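

# A minimal, hypothetical smoke test (not part of the original file). It assumes at
# least two CUDA devices and uses a toy 1x1-conv "segmentation" head; run the file
# directly to check that the parallel wrappers execute end to end.
if __name__ == '__main__':
    import torch.nn as nn

    if torch.cuda.device_count() >= 2:
        device_ids = list(range(torch.cuda.device_count()))
        model = nn.Conv2d(3, 4, kernel_size=1).cuda()   # toy 4-class "segmentation" head
        criterion = nn.CrossEntropyLoss().cuda()

        net = DataParallelModel(model, device_ids=device_ids)
        parallel_criterion = DataParallelCriterion(criterion, device_ids=device_ids)

        images = torch.randn(2 * len(device_ids), 3, 8, 8).cuda()
        masks = torch.randint(0, 4, (2 * len(device_ids), 8, 8)).cuda()

        outputs = net(images)                      # list of per-GPU logits
        loss = parallel_criterion(outputs, masks)  # scalar loss averaged over GPUs
        loss.backward()
        print('loss:', loss.item())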