- """
- This file contains primitives for multi-gpu communication.
- This is useful when doing distributed training.
- """
- import math
- import pickle
- import torch
- import torch.utils.data as data
- import torch.distributed as dist
-
- from torch.utils.data.sampler import Sampler, BatchSampler
-
- __all__ = ['get_world_size', 'get_rank', 'synchronize', 'is_main_process',
- 'all_gather', 'make_data_sampler', 'make_batch_data_sampler',
- 'reduce_dict', 'reduce_loss_dict']
-
-
- # reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/comm.py
- def get_world_size():
- if not dist.is_available():
- return 1
- if not dist.is_initialized():
- return 1
- return dist.get_world_size()
-
-
- def get_rank():
- if not dist.is_available():
- return 0
- if not dist.is_initialized():
- return 0
- return dist.get_rank()
-
-
- def is_main_process():
- return get_rank() == 0
-
-
- def synchronize():
- """
- Helper function to synchronize (barrier) among all processes when
- using distributed training
- """
- if not dist.is_available():
- return
- if not dist.is_initialized():
- return
- world_size = dist.get_world_size()
- if world_size == 1:
- return
- dist.barrier()
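

# Usage sketch (illustrative, not part of this module): synchronize() is a
# barrier, typically called when one rank produces files or state that the
# other ranks are about to consume.
#
#   if is_main_process():
#       prepare_cache()      # hypothetical helper that writes shared files
#   synchronize()            # every rank waits here until rank 0 is done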


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize the object to a byte tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain the tensor size of each rank
    local_size = torch.IntTensor([tensor.numel()]).to("cuda")
    size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receive tensors from all ranks;
    # we pad the local tensor because torch.distributed.all_gather does not
    # support gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
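

# Usage sketch for all_gather (illustrative; assumes torch.distributed was
# initialized with a CUDA-capable backend such as NCCL, since the helper
# stages the pickled bytes on "cuda"; `predictions` is a placeholder):
#
#   local_result = {'rank': get_rank(), 'predictions': predictions}   # any picklable object
#   gathered = all_gather(local_result)    # list with one entry per rank, in rank order
#   if is_main_process():
#       merged = [p for r in gathered for p in r['predictions']]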


def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that the process
    with rank 0 has the reduced results. Returns a dict with the same fields as
    input_dict, after reduction.

    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only the main process accumulates the sum, so only divide by
            # world_size in that process
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
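

# Usage sketch for reduce_dict (illustrative; the metric values below are made
# up and must be tensors of identical shape and key set on every rank, on CUDA
# when using an NCCL backend):
#
#   metrics = {'pixAcc': torch.tensor(0.93, device='cuda'),
#              'mIoU': torch.tensor(0.71, device='cuda')}
#   reduced = reduce_dict(metrics, average=True)
#   if is_main_process():
#       print({k: v.item() for k, v in reduced.items()})   # averaged values; only valid on rank 0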


def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that the process with rank
    0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only the main process accumulates the sum, so only divide by
            # world_size in that process
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses
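

# Usage sketch inside a training loop (illustrative; `model`, `logger` and the
# structure of `loss_dict` are placeholders):
#
#   loss_dict = model(images, targets)                 # e.g. {'loss_cls': ..., 'loss_reg': ...}
#   losses = sum(loss for loss in loss_dict.values())
#   losses.backward()
#   loss_dict_reduced = reduce_loss_dict(loss_dict)    # averaged across ranks, valid on rank 0
#   if is_main_process():
#       logger.info(' '.join('{}: {:.4f}'.format(k, v.item()) for k, v in loss_dict_reduced.items()))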


def make_data_sampler(dataset, shuffle, distributed):
    """Return a sampler for `dataset`: a DistributedSampler when `distributed`
    is True, otherwise a RandomSampler or SequentialSampler depending on `shuffle`."""
    if distributed:
        return DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        sampler = data.sampler.RandomSampler(dataset)
    else:
        sampler = data.sampler.SequentialSampler(dataset)
    return sampler


def make_batch_data_sampler(sampler, images_per_batch, num_iters=None, start_iter=0):
    """Wrap `sampler` in a BatchSampler (dropping the last incomplete batch); if
    `num_iters` is given, further wrap it so iteration stops after `num_iters` batches."""
    batch_sampler = data.sampler.BatchSampler(sampler, images_per_batch, drop_last=True)
    if num_iters is not None:
        batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)
    return batch_sampler
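

# Usage sketch wiring these helpers into a DataLoader (illustrative; `dataset`,
# `batch_size`, `max_iters` and `args.distributed` are placeholders):
#
#   sampler = make_data_sampler(dataset, shuffle=True, distributed=args.distributed)
#   batch_sampler = make_batch_data_sampler(sampler, images_per_batch=batch_size,
#                                           num_iters=max_iters, start_iter=0)
#   loader = data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=4)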


# Code is copy-pasted from https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/data/samplers/distributed.py
class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        shuffle (optional): If True (default), shuffle the indices
            deterministically based on the epoch.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make the index list evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample the slice belonging to this rank
        offset = self.num_samples * self.rank
        indices = indices[offset: offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
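

# Usage sketch for plain epoch-based training with DistributedSampler
# (illustrative; `dataset` and `num_epochs` are placeholders). Calling
# set_epoch(epoch) before each epoch reseeds the shuffle so every epoch sees a
# new permutation while all ranks still agree on it.
#
#   sampler = DistributedSampler(dataset, shuffle=True)
#   loader = data.DataLoader(dataset, batch_size=8, sampler=sampler, drop_last=True)
#   for epoch in range(num_epochs):
#       sampler.set_epoch(epoch)
#       for images, targets in loader:
#           ...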


class IterationBasedBatchSampler(BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until
    a specified number of iterations have been sampled.
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration <= self.num_iterations:
            # if the underlying sampler has a set_epoch method, like
            # DistributedSampler, used for making each process see
            # a different split of the dataset, then set it
            if hasattr(self.batch_sampler.sampler, "set_epoch"):
                self.batch_sampler.sampler.set_epoch(iteration)
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations
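

# Note on resuming (an observation, not extra API): with start_iter > 0 the
# wrapper still reports len() == num_iterations but only yields the remaining
# num_iterations - start_iter batches, which matches restarting training from a
# checkpointed iteration counter.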


if __name__ == '__main__':
    pass