"""Utils for Semantic Segmentation"""
|
|
import threading
|
|
import torch
|
|
import torch.cuda.comm as comm
|
|
from torch.nn.parallel.data_parallel import DataParallel
|
|
from torch.nn.parallel._functions import Broadcast
|
|
from torch.autograd import Function
|
|
|
|
__all__ = ['DataParallelModel', 'DataParallelCriterion']
|
|
|
|
|
|
class Reduce(Function):
    """Sums tensors that live on different GPUs onto a single device.

    The backward pass broadcasts the incoming gradient back to every device
    that contributed an input.
    """

    @staticmethod
    def forward(ctx, *inputs):
        # Remember the device of every input (in the original order) so the
        # gradient can be broadcast back to the same devices in backward().
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, grad_output):
        return Broadcast.apply(ctx.target_gpus, grad_output)

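# A rough sketch of what Reduce does (assuming at least two visible CUDA
# devices; the tensors are only illustrative):
#
#   a = torch.ones(2, device='cuda:0')
#   b = torch.ones(2, device='cuda:1')
#   s = Reduce.apply(a, b)   # tensor([2., 2.]) placed on a single device
#
# In the backward pass the gradient of `s` is broadcast back to cuda:0 and
# cuda:1, so every contributing input receives a full copy of the gradient.

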
class DataParallelModel(DataParallel):
    """Data parallelism.

    Hides the difference between single- and multi-GPU execution from the user.
    In the forward pass, the module is replicated on each device, and each
    replica handles a portion of the input. During the backward pass, gradients
    from each replica are summed into the original module.

    The batch size should be larger than the number of GPUs used.

    Unlike :class:`torch.nn.DataParallel`, the per-device outputs are *not*
    gathered onto a single device; they are returned as a list so that the loss
    can be computed in parallel by :class:`DataParallelCriterion`.

    Parameters
    ----------
    module : nn.Module
        Network to be parallelized.
    device_ids : list of int, optional
        CUDA devices to use (defaults to all visible devices).

    Inputs:
        - **inputs**: list of input
    Outputs:
        - **outputs**: list of output, one per device

    Example::

        >>> net = DataParallelModel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)  # input_var can be on any device, including CPU
    """

    def gather(self, outputs, output_device):
        # Keep the per-device outputs where they are instead of gathering them
        # onto `output_device`; DataParallelCriterion consumes them in place.
        return outputs

    def replicate(self, module, device_ids):
        modules = super(DataParallelModel, self).replicate(module, device_ids)
        return modules

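# Behavioural note (a minimal sketch, assuming a model whose forward returns a
# single tensor and at least two GPUs):
#
#   net = DataParallelModel(model, device_ids=[0, 1])
#   outputs = net(images)    # list with one entry per device, NOT gathered
#   # outputs[0] lives on cuda:0, outputs[1] on cuda:1
#
# Keeping the outputs on their own devices is what allows DataParallelCriterion
# below to compute the loss in parallel instead of on a single GPU.

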
# Reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py
class DataParallelCriterion(DataParallel):
    """Computes the loss on multiple GPUs, which balances memory usage for
    semantic segmentation.

    The targets are split across the specified devices by chunking in the
    batch dimension. Use together with :class:`DataParallelModel`.

    Example::

        >>> net = DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """

    def forward(self, inputs, *targets, **kwargs):
        # `inputs` is expected to be the list of per-device outputs produced
        # by DataParallelModel.
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        # Split the targets across the devices along the batch dimension.
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        # Evaluate the criterion on every device, then average the per-device
        # losses with a differentiable reduction.
        outputs = criterion_parallel_apply(replicas, inputs, targets, kwargs)
        return Reduce.apply(*outputs) / len(outputs)

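# A minimal end-to-end training step (a sketch; `model`, `base_criterion`,
# `images`, `labels`, and `optimizer` are hypothetical placeholders):
#
#   net = DataParallelModel(model, device_ids=[0, 1])
#   criterion = DataParallelCriterion(base_criterion, device_ids=[0, 1])
#
#   outputs = net(images)              # per-device outputs (a list)
#   loss = criterion(outputs, labels)  # scalar loss, averaged over devices
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()

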
def get_a_var(obj):
    """Return the first :class:`torch.Tensor` found in ``obj``, searching
    recursively through lists, tuples, and dict items; ``None`` if absent."""
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result

    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None


def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on the arguments
    contained in :attr:`inputs` (positional), :attr:`targets` (positional) and
    :attr:`kwargs_tup` (keyword) on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        targets (tensor): targets to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`targets`, :attr:`kwargs_tup` (if given),
    and :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        # Propagate the caller's grad mode into this worker thread.
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # Normalize to tuples so a bare Tensor is passed as a single
                # argument instead of being iterated over, and so that mixed
                # list/tuple arguments can be concatenated safely.
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                if not isinstance(target, (list, tuple)):
                    target = (target,)
                output = module(*(tuple(input) + tuple(target)), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target, kwargs, device))
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            # Re-raise any exception captured inside a worker thread.
            raise output
        outputs.append(output)
    return outputs

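# Sketch of how criterion_parallel_apply is wired up internally (it is not
# normally called directly; `replicas`, `model_outputs`, and `scattered_targets`
# are hypothetical names for the values DataParallelCriterion.forward builds,
# assuming two GPUs):
#
#   losses = criterion_parallel_apply(replicas,          # one criterion per GPU
#                                     model_outputs,     # one model output per GPU
#                                     scattered_targets) # one target chunk per GPU
#   loss = Reduce.apply(*losses) / len(losses)           # differentiable average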