  1. """
  2. This file contains primitives for multi-gpu communication.
  3. This is useful when doing distributed training.
  4. """
  5. import math
  6. import pickle
  7. import torch
  8. import torch.utils.data as data
  9. import torch.distributed as dist
  10. from torch.utils.data.sampler import Sampler, BatchSampler
  11. __all__ = ['get_world_size', 'get_rank', 'synchronize', 'is_main_process',
  12. 'all_gather', 'make_data_sampler', 'make_batch_data_sampler',
  13. 'reduce_dict', 'reduce_loss_dict']
  14. # reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/comm.py
def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0

def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()
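

# A minimal usage sketch: guard rank-0-only work such as saving checkpoints
# with is_main_process(), then call synchronize() so the other ranks wait
# until the file exists. `save_on_main_process`, `state_dict` and `path` are
# illustrative names, not part of the reference comm.py.
def save_on_main_process(state_dict, path):
    if is_main_process():
        torch.save(state_dict, path)
    # every rank blocks here until rank 0 has finished writing
    synchronize()
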
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    Args:
        data: any picklable object

    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.IntTensor([tensor.numel()]).to("cuda")
    size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))
    return data_list
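

# A minimal usage sketch for all_gather: collect per-rank evaluation results
# on the main process. `gather_results` and `local_results` are illustrative
# names; `local_results` is assumed to be a picklable list built on this rank.
def gather_results(local_results):
    all_results = all_gather(local_results)
    if not is_main_process():
        return None
    # flatten the list of per-rank lists on rank 0
    merged = []
    for per_rank in all_results:
        merged.extend(per_rank)
    return merged
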
def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that the process
    with rank 0 has the averaged results. Returns a dict with the same fields
    as input_dict, after reduction.

    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that the process with
    rank 0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses
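

# A minimal usage sketch for reduce_loss_dict: log losses averaged over ranks
# while keeping the un-reduced losses for backward() on each rank.
# `log_reduced_losses` is an illustrative helper; `loss_dict` is assumed to be
# a dict of scalar CUDA loss tensors, as produced by a typical detection model.
def log_reduced_losses(loss_dict, iteration):
    loss_dict_reduced = reduce_loss_dict(loss_dict)
    if is_main_process():
        total_loss = sum(loss for loss in loss_dict_reduced.values())
        print("iter {}: total loss {:.4f}".format(iteration, total_loss.item()))
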
def make_data_sampler(dataset, shuffle, distributed):
    if distributed:
        return DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        sampler = data.sampler.RandomSampler(dataset)
    else:
        sampler = data.sampler.SequentialSampler(dataset)
    return sampler


def make_batch_data_sampler(sampler, images_per_batch, num_iters=None, start_iter=0):
    batch_sampler = data.sampler.BatchSampler(sampler, images_per_batch, drop_last=True)
    if num_iters is not None:
        batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)
    return batch_sampler
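

# A minimal usage sketch combining the two helpers above into an
# iteration-based training loader. `make_train_loader` is an illustrative
# name; `dataset`, `batch_size`, `max_iters` and `num_workers` are assumed to
# come from the training configuration.
def make_train_loader(dataset, batch_size, max_iters, distributed=False, num_workers=4):
    sampler = make_data_sampler(dataset, shuffle=True, distributed=distributed)
    batch_sampler = make_batch_data_sampler(sampler, batch_size, num_iters=max_iters)
    return data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers)
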
# Code is copy-pasted from https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/data/samplers/distributed.py
class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        offset = self.num_samples * self.rank
        indices = indices[offset: offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


class IterationBasedBatchSampler(BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until
    a specified number of iterations have been sampled
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration <= self.num_iterations:
            # if the underlying sampler has a set_epoch method, like
            # DistributedSampler, used for making each process see
            # a different split of the dataset, then set it
            if hasattr(self.batch_sampler.sampler, "set_epoch"):
                self.batch_sampler.sampler.set_epoch(iteration)
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations
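

# A minimal usage sketch with toy sizes (illustrative only): the wrapper keeps
# drawing batches from the underlying BatchSampler until num_iterations
# batches have been yielded, restarting the base sampler as needed.
def _demo_iteration_based_sampler():
    base = BatchSampler(data.sampler.SequentialSampler(range(10)), batch_size=4, drop_last=True)
    wrapped = IterationBasedBatchSampler(base, num_iterations=5)
    for i, batch in enumerate(wrapped, 1):
        print("iteration {}: {}".format(i, batch))
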
if __name__ == '__main__':
    pass