""" This file contains primitives for multi-gpu communication. This is useful when doing distributed training. """ import math import pickle import torch import torch.utils.data as data import torch.distributed as dist from torch.utils.data.sampler import Sampler, BatchSampler __all__ = ['get_world_size', 'get_rank', 'synchronize', 'is_main_process', 'all_gather', 'make_data_sampler', 'make_batch_data_sampler', 'reduce_dict', 'reduce_loss_dict'] # reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/comm.py def get_world_size(): if not dist.is_available(): return 1 if not dist.is_initialized(): return 1 return dist.get_world_size() def get_rank(): if not dist.is_available(): return 0 if not dist.is_initialized(): return 0 return dist.get_rank() def is_main_process(): return get_rank() == 0 def synchronize(): """ Helper function to synchronize (barrier) among all processes when using distributed training """ if not dist.is_available(): return if not dist.is_initialized(): return world_size = dist.get_world_size() if world_size == 1: return dist.barrier() def all_gather(data): """ Run all_gather on arbitrary picklable data (not necessarily tensors) Args: data: any picklable object Returns: list[data]: list of data gathered from each rank """ world_size = get_world_size() if world_size == 1: return [data] # serialized to a Tensor buffer = pickle.dumps(data) storage = torch.ByteStorage.from_buffer(buffer) tensor = torch.ByteTensor(storage).to("cuda") # obtain Tensor size of each rank local_size = torch.IntTensor([tensor.numel()]).to("cuda") size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] dist.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) # receiving Tensor from all ranks # we pad the tensor because torch all_gather does not support # gathering tensors of different shapes tensor_list = [] for _ in size_list: tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) if local_size != max_size: padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") tensor = torch.cat((tensor, padding), dim=0) dist.all_gather(tensor_list, tensor) data_list = [] for size, tensor in zip(size_list, tensor_list): buffer = tensor.cpu().numpy().tobytes()[:size] data_list.append(pickle.loads(buffer)) return data_list def reduce_dict(input_dict, average=True): """ Args: input_dict (dict): all the values will be reduced average (bool): whether to do average or sum Reduce the values in the dictionary from all processes so that process with rank 0 has the averaged results. Returns a dict with the same fields as input_dict, after reduction. """ world_size = get_world_size() if world_size < 2: return input_dict with torch.no_grad(): names = [] values = [] # sort the keys so that they are consistent across processes for k in sorted(input_dict.keys()): names.append(k) values.append(input_dict[k]) values = torch.stack(values, dim=0) dist.reduce(values, dst=0) if dist.get_rank() == 0 and average: # only main process gets accumulated, so only divide by # world_size in this case values /= world_size reduced_dict = {k: v for k, v in zip(names, values)} return reduced_dict def reduce_loss_dict(loss_dict): """ Reduce the loss dictionary from all processes so that process with rank 0 has the averaged results. Returns a dict with the same fields as loss_dict, after reduction. """ world_size = get_world_size() if world_size < 2: return loss_dict with torch.no_grad(): loss_names = [] all_losses = [] for k in sorted(loss_dict.keys()): loss_names.append(k) all_losses.append(loss_dict[k]) all_losses = torch.stack(all_losses, dim=0) dist.reduce(all_losses, dst=0) if dist.get_rank() == 0: # only main process gets accumulated, so only divide by # world_size in this case all_losses /= world_size reduced_losses = {k: v for k, v in zip(loss_names, all_losses)} return reduced_losses def make_data_sampler(dataset, shuffle, distributed): if distributed: return DistributedSampler(dataset, shuffle=shuffle) if shuffle: sampler = data.sampler.RandomSampler(dataset) else: sampler = data.sampler.SequentialSampler(dataset) return sampler def make_batch_data_sampler(sampler, images_per_batch, num_iters=None, start_iter=0): batch_sampler = data.sampler.BatchSampler(sampler, images_per_batch, drop_last=True) if num_iters is not None: batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iters, start_iter) return batch_sampler # Code is copy-pasted from https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/data/samplers/distributed.py class DistributedSampler(Sampler): """Sampler that restricts data loading to a subset of the dataset. It is especially useful in conjunction with :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it. .. note:: Dataset is assumed to be of constant size. Arguments: dataset: Dataset used for sampling. num_replicas (optional): Number of processes participating in distributed training. rank (optional): Rank of the current process within num_replicas. """ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") num_replicas = dist.get_world_size() if rank is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() self.dataset = dataset self.num_replicas = num_replicas self.rank = rank self.epoch = 0 self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle def __iter__(self): if self.shuffle: # deterministically shuffle based on epoch g = torch.Generator() g.manual_seed(self.epoch) indices = torch.randperm(len(self.dataset), generator=g).tolist() else: indices = torch.arange(len(self.dataset)).tolist() # add extra samples to make it evenly divisible indices += indices[: (self.total_size - len(indices))] assert len(indices) == self.total_size # subsample offset = self.num_samples * self.rank indices = indices[offset: offset + self.num_samples] assert len(indices) == self.num_samples return iter(indices) def __len__(self): return self.num_samples def set_epoch(self, epoch): self.epoch = epoch class IterationBasedBatchSampler(BatchSampler): """ Wraps a BatchSampler, resampling from it until a specified number of iterations have been sampled """ def __init__(self, batch_sampler, num_iterations, start_iter=0): self.batch_sampler = batch_sampler self.num_iterations = num_iterations self.start_iter = start_iter def __iter__(self): iteration = self.start_iter while iteration <= self.num_iterations: # if the underlying sampler has a set_epoch method, like # DistributedSampler, used for making each process see # a different split of the dataset, then set it if hasattr(self.batch_sampler.sampler, "set_epoch"): self.batch_sampler.sampler.set_epoch(iteration) for batch in self.batch_sampler: iteration += 1 if iteration > self.num_iterations: break yield batch def __len__(self): return self.num_iterations if __name__ == '__main__': pass