if rank != -1: | if rank != -1: | ||||
indices = torch.zeros([dataset.n], dtype=torch.int) | indices = torch.zeros([dataset.n], dtype=torch.int) | ||||
if rank == 0: | if rank == 0: | ||||
indices[:] = torch.from_tensor(dataset.indices, dtype=torch.int) | |||||
indices[:] = torch.tensor(dataset.indices, dtype=torch.int) | |||||
dist.broadcast(indices, 0) | dist.broadcast(indices, 0) | ||||
if rank != 0: | if rank != 0: | ||||
dataset.indices = indices.cpu().numpy() | dataset.indices = indices.cpu().numpy() |