"""Utils for Semantic Segmentation""" import threading import torch import torch.cuda.comm as comm from torch.nn.parallel.data_parallel import DataParallel from torch.nn.parallel._functions import Broadcast from torch.autograd import Function __all__ = ['DataParallelModel', 'DataParallelCriterion'] class Reduce(Function): @staticmethod def forward(ctx, *inputs): ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] inputs = sorted(inputs, key=lambda i: i.get_device()) return comm.reduce_add(inputs) @staticmethod def backward(ctx, gradOutputs): return Broadcast.apply(ctx.target_gpus, gradOutputs) class DataParallelModel(DataParallel): """Data parallelism Hide the difference of single/multiple GPUs to the user. In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module. The batch size should be larger than the number of GPUs used. Parameters ---------- module : object Network to be parallelized. sync : bool enable synchronization (default: False). Inputs: - **inputs**: list of input Outputs: - **outputs**: list of output Example:: >>> net = DataParallelModel(model, device_ids=[0, 1, 2]) >>> output = net(input_var) # input_var can be on any device, including CPU """ def gather(self, outputs, output_device): return outputs def replicate(self, module, device_ids): modules = super(DataParallelModel, self).replicate(module, device_ids) return modules # Reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py class DataParallelCriterion(DataParallel): """ Calculate loss in multiple-GPUs, which balance the memory usage for Semantic Segmentation. The targets are splitted across the specified devices by chunking in the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. Example:: >>> net = DataParallelModel(model, device_ids=[0, 1, 2]) >>> criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2]) >>> y = net(x) >>> loss = criterion(y, target) """ def forward(self, inputs, *targets, **kwargs): # the inputs should be the outputs of DataParallelModel if not self.device_ids: return self.module(inputs, *targets, **kwargs) targets, kwargs = self.scatter(targets, kwargs, self.device_ids) if len(self.device_ids) == 1: return self.module(inputs, *targets[0], **kwargs[0]) replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) outputs = criterion_parallel_apply(replicas, inputs, targets, kwargs) return Reduce.apply(*outputs) / len(outputs) def get_a_var(obj): if isinstance(obj, torch.Tensor): return obj if isinstance(obj, list) or isinstance(obj, tuple): for result in map(get_a_var, obj): if isinstance(result, torch.Tensor): return result if isinstance(obj, dict): for result in map(get_a_var, obj.items()): if isinstance(result, torch.Tensor): return result return None def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): r"""Applies each `module` in :attr:`modules` in parallel on arguments contained in :attr:`inputs` (positional), attr:'targets' (positional) and :attr:`kwargs_tup` (keyword) on each of :attr:`devices`. 
Args: modules (Module): modules to be parallelized inputs (tensor): inputs to the modules targets (tensor): targets to the modules devices (list of int or torch.device): CUDA devices :attr:`modules`, :attr:`inputs`, :attr:'targets' :attr:`kwargs_tup` (if given), and :attr:`devices` (if given) should all have same length. Moreover, each element of :attr:`inputs` can either be a single object as the only argument to a module, or a collection of positional arguments. """ assert len(modules) == len(inputs) assert len(targets) == len(inputs) if kwargs_tup is not None: assert len(modules) == len(kwargs_tup) else: kwargs_tup = ({},) * len(modules) if devices is not None: assert len(modules) == len(devices) else: devices = [None] * len(modules) lock = threading.Lock() results = {} grad_enabled = torch.is_grad_enabled() def _worker(i, module, input, target, kwargs, device=None): torch.set_grad_enabled(grad_enabled) if device is None: device = get_a_var(input).get_device() try: with torch.cuda.device(device): output = module(*(list(input) + target), **kwargs) with lock: results[i] = output except Exception as e: with lock: results[i] = e if len(modules) > 1: threads = [threading.Thread(target=_worker, args=(i, module, input, target, kwargs, device)) for i, (module, input, target, kwargs, device) in enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] for thread in threads: thread.start() for thread in threads: thread.join() else: _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0]) outputs = [] for i in range(len(inputs)): output = results[i] if isinstance(output, Exception): raise output outputs.append(output) return outputs
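

# Usage sketch (illustrative only): it assumes at least two CUDA devices and a
# hypothetical toy model. The per-GPU model output is assumed to be a tuple of
# tensors, since `criterion_parallel_apply` unpacks it as positional arguments
# for the criterion.
if __name__ == '__main__':
    import torch.nn as nn

    class _ToySegHead(nn.Module):
        """Hypothetical 1x1-conv 'segmentation' model returning a tuple of outputs."""

        def __init__(self, num_classes=4):
            super(_ToySegHead, self).__init__()
            self.head = nn.Conv2d(3, num_classes, kernel_size=1)

        def forward(self, x):
            return (self.head(x),)

    if torch.cuda.device_count() >= 2:
        device_ids = [0, 1]
        model = DataParallelModel(_ToySegHead().cuda(), device_ids=device_ids)
        criterion = DataParallelCriterion(nn.CrossEntropyLoss(), device_ids=device_ids)

        images = torch.randn(4, 3, 8, 8).cuda()        # batch of 4, so each GPU gets 2 samples
        masks = torch.randint(0, 4, (4, 8, 8)).cuda()  # per-pixel integer class labels

        outputs = model(images)           # list of per-GPU outputs (gather is a no-op)
        loss = criterion(outputs, masks)  # scatters the targets, averages the per-GPU losses
        loss.backward()
        print('loss:', loss.item())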