# AIlib2/segutils/core/utils/parallel.py
#
# 163 lines
# 5.7 KiB
# Python

"""Utils for Semantic Segmentation"""
import threading
import torch
import torch.cuda.comm as comm
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel._functions import Broadcast
from torch.autograd import Function
__all__ = ['DataParallelModel', 'DataParallelCriterion']
class Reduce(Function):
    """Autograd function that adds up tensors held on several GPUs.

    ``forward`` records the device index of every input (in call order),
    then hands the inputs, ordered by device index, to ``comm.reduce_add``.
    ``backward`` broadcasts the incoming gradient to the recorded devices.
    """

    @staticmethod
    def forward(ctx, *inputs):
        # Keep the devices in the caller's order; backward needs them to
        # route one gradient per original input.
        ctx.target_gpus = [tensor.get_device() for tensor in inputs]
        by_device = sorted(inputs, key=lambda tensor: tensor.get_device())
        return comm.reduce_add(by_device)

    @staticmethod
    def backward(ctx, gradOutputs):
        target_gpus = ctx.target_gpus
        return Broadcast.apply(target_gpus, gradOutputs)
class DataParallelModel(DataParallel):
    """``DataParallel`` variant that leaves per-device outputs un-gathered.

    The only behavioural difference from the parent class visible here is
    :meth:`gather`: instead of collecting the replica outputs onto one
    device, it hands them back unchanged, so the caller receives one
    output per replica. This lets a parallel criterion compute the loss
    on the device where each output already lives.

    Parameters
    ----------
    module : object
        Network to be parallelized.

    Inputs:
        - **inputs**: list of input

    Outputs:
        - **outputs**: list of output (one entry per replica)

    Example::

        >>> net = DataParallelModel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)  # input_var can be on any device, including CPU
    """

    def gather(self, outputs, output_device):
        # Intentionally a pass-through: ``output_device`` is ignored.
        return outputs

    def replicate(self, module, device_ids):
        # Delegates straight to the parent implementation.
        return super().replicate(module, device_ids)
# Reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py
class DataParallelCriterion(DataParallel):
    """Compute a loss on several GPUs to balance memory usage.

    The targets are split across the configured devices via ``scatter``;
    the predictions are expected to already be a per-device sequence, as
    produced by :class:`DataParallelModel`. The per-device losses are summed
    with ``Reduce`` and divided by their count.

    Example::

        >>> net = DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """

    def forward(self, inputs, *targets, **kwargs):
        # ``inputs`` should be the outputs of DataParallelModel.
        device_ids = self.device_ids

        # No devices configured: call the wrapped criterion directly.
        if not device_ids:
            return self.module(inputs, *targets, **kwargs)

        scattered_targets, scattered_kwargs = self.scatter(targets, kwargs, device_ids)

        # Single device: no replication needed, use the first (only) shard.
        if len(device_ids) == 1:
            return self.module(inputs, *scattered_targets[0], **scattered_kwargs[0])

        # One replica per available input shard.
        replicas = self.replicate(self.module, device_ids[:len(inputs)])
        losses = criterion_parallel_apply(
            replicas, inputs, scattered_targets, scattered_kwargs
        )
        return Reduce.apply(*losses) / len(losses)
def get_a_var(obj):
    """Return the first ``torch.Tensor`` found in ``obj``, or ``None``.

    ``obj`` may be a tensor itself, or a list/tuple/dict that is searched
    recursively in iteration order. For a dict, the ``(key, value)`` pairs
    are searched, so a tensor used as a key is also found. Any other type
    yields ``None``.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        children = obj
    elif isinstance(obj, dict):
        children = obj.items()
    else:
        return None

    for child in children:
        found = get_a_var(child)
        if isinstance(found, torch.Tensor):
            return found
    return None
def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    r"""Apply each module in :attr:`modules` in parallel to its own arguments.

    Replica ``i`` is called as
    ``modules[i](*inputs[i], *targets[i], **kwargs_tup[i])`` inside a
    ``torch.cuda.device(devices[i])`` context. With more than one module,
    each call runs in its own thread; with exactly one, it runs inline.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules; each element is either a
            single object (passed as one positional argument) or a
            list/tuple of positional arguments
        targets (tensor): targets to the modules; same convention as
            :attr:`inputs`
        kwargs_tup (sequence of dict, optional): keyword arguments per
            module; defaults to an empty dict for every module
        devices (list of int or torch.device, optional): CUDA devices; when
            omitted, the device of the first tensor found in the matching
            input is used

    :attr:`modules`, :attr:`inputs`, :attr:`targets`, :attr:`kwargs_tup`
    (if given), and :attr:`devices` (if given) should all have same length.

    Returns:
        list: the output of every module, in the order of :attr:`modules`.

    Raises:
        AssertionError: if the argument lengths do not match.
        Exception: the first (by index) exception raised by a worker is
            re-raised in the calling thread.
    """
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    # Captured here so every worker runs under the caller's grad mode.
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, inp, target, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        try:
            # Device lookup lives inside the try so that a failure here is
            # reported to the caller instead of leaving results[i] unset.
            if device is None:
                device = get_a_var(inp).get_device()
            # Wrap bare objects so a lone tensor is passed whole rather
            # than being split along its first dimension.
            if not isinstance(inp, (list, tuple)):
                inp = (inp,)
            if not isinstance(target, (list, tuple)):
                target = (target,)
            # Normalise both to tuples: list + tuple would raise TypeError.
            args = tuple(inp) + tuple(target)
            with torch.cuda.device(device):
                output = module(*args, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [
            threading.Thread(
                target=_worker,
                args=(i, module, inp, target, kwargs, device),
            )
            for i, (module, inp, target, kwargs, device) in enumerate(
                zip(modules, inputs, targets, kwargs_tup, devices)
            )
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i, _ in enumerate(inputs):
        output = results[i]
        # Surface worker failures in the calling thread.
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs