- """Utils for Semantic Segmentation"""
- import threading
- import torch
- import torch.cuda.comm as comm
- from torch.nn.parallel.data_parallel import DataParallel
- from torch.nn.parallel._functions import Broadcast
- from torch.autograd import Function
-
- __all__ = ['DataParallelModel', 'DataParallelCriterion']
-
-
class Reduce(Function):
    """Sums tensors that live on different GPUs; the backward pass broadcasts
    the incoming gradient back to every contributing device."""

    @staticmethod
    def forward(ctx, *inputs):
        # Remember where each input lives so gradients can be routed back.
        ctx.target_gpus = [inp.get_device() for inp in inputs]
        inputs = sorted(inputs, key=lambda inp: inp.get_device())
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, grad_output):
        return Broadcast.apply(ctx.target_gpus, grad_output)
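
# Illustrative sketch of what `Reduce` provides (not executed here; assumes
# two CUDA devices, and `loss0`/`loss1` are hypothetical per-GPU scalar losses):
#
#   loss0 = ...                          # loss tensor living on cuda:0
#   loss1 = ...                          # loss tensor living on cuda:1
#   total = Reduce.apply(loss0, loss1)   # summed onto the first device, still differentiable
#   total.backward()                     # gradients flow back to both devices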


class DataParallelModel(DataParallel):
    """Data parallelism.

    Hides the difference between single- and multi-GPU execution from the user.
    In the forward pass, the module is replicated on each device and each
    replica handles a portion of the input. During the backward pass,
    gradients from each replica are summed into the original module.

    Unlike :class:`torch.nn.DataParallel`, the per-device outputs are *not*
    gathered onto a single device; ``forward`` returns a list with one entry
    per GPU, which is what :class:`DataParallelCriterion` expects.

    The batch size should be larger than the number of GPUs used.

    Parameters
    ----------
    module : nn.Module
        Network to be parallelized.
    device_ids : list of int, optional
        GPUs to use (defaults to all visible devices).

    Inputs:
        - **inputs**: list of input
    Outputs:
        - **outputs**: list of output (one element per GPU)

    Example::

        >>> net = DataParallelModel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)  # input_var can be on any device, including CPU
    """

    def gather(self, outputs, output_device):
        # Keep the per-GPU outputs where they are instead of gathering them
        # onto ``output_device``; the criterion consumes them in place.
        return outputs

    def replicate(self, module, device_ids):
        # Default replication; the override is a convenient place to customize
        # replica creation in subclasses.
        modules = super(DataParallelModel, self).replicate(module, device_ids)
        return modules

-
- # Reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py
- class DataParallelCriterion(DataParallel):
- """
- Calculate loss in multiple-GPUs, which balance the memory usage for
- Semantic Segmentation.
-
- The targets are splitted across the specified devices by chunking in
- the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
-
- Example::
- >>> net = DataParallelModel(model, device_ids=[0, 1, 2])
- >>> criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2])
- >>> y = net(x)
- >>> loss = criterion(y, target)
- """
-
- def forward(self, inputs, *targets, **kwargs):
- # the inputs should be the outputs of DataParallelModel
- if not self.device_ids:
- return self.module(inputs, *targets, **kwargs)
- targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
- if len(self.device_ids) == 1:
- return self.module(inputs, *targets[0], **kwargs[0])
- replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
- outputs = criterion_parallel_apply(replicas, inputs, targets, kwargs)
- return Reduce.apply(*outputs) / len(outputs)
-
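
# Illustrative training step under the usual pairing (a minimal sketch; the
# names `model`, `seg_criterion`, `images` and `masks` are placeholders, and
# at least two CUDA devices are assumed):
#
#   net = DataParallelModel(model, device_ids=[0, 1])
#   criterion = DataParallelCriterion(seg_criterion, device_ids=[0, 1])
#   outputs = net(images)             # list: one output (or tuple of outputs) per GPU
#   loss = criterion(outputs, masks)  # scalar: per-GPU losses combined by `Reduce`
#   loss.backward()                   # gradients reach every replica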


def get_a_var(obj):
    """Return the first Tensor found in ``obj`` (possibly nested in lists,
    tuples or dicts), or ``None`` if there is none."""
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result

    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None


def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on the arguments
    contained in :attr:`inputs` (positional), :attr:`targets` (positional) and
    :attr:`kwargs_tup` (keyword) on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        targets (tensor): targets to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`targets`, :attr:`kwargs_tup` (if
    given), and :attr:`devices` (if given) should all have the same length.
    Moreover, each element of :attr:`inputs` can either be a single object as
    the only argument to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # Wrap bare tensors so a Tensor ``input`` is not accidentally
                # sliced along dim 0, and so list/tuple mixes concatenate.
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                if not isinstance(target, (list, tuple)):
                    target = (target,)
                output = module(*(tuple(input) + tuple(target)), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target, kwargs, device))
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
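

if __name__ == '__main__':
    # Minimal smoke-test sketch, assuming at least two visible CUDA devices.
    # The tiny model, criterion and data below are placeholders added for
    # illustration; they are not part of the original module.
    import torch.nn as nn

    if torch.cuda.device_count() >= 2:
        device_ids = [0, 1]
        model = nn.Conv2d(3, 5, kernel_size=3, padding=1).cuda()
        net = DataParallelModel(model, device_ids=device_ids)
        criterion = DataParallelCriterion(nn.CrossEntropyLoss(), device_ids=device_ids)

        images = torch.randn(4, 3, 32, 32)              # batch larger than #GPUs
        masks = torch.randint(0, 5, (4, 32, 32)).cuda()  # integer class labels

        outputs = net(images)              # list of per-GPU logits
        loss = criterion(outputs, masks)   # scalar on device_ids[0]
        loss.backward()
        print('loss:', loss.item())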