diff --git a/train.py b/train.py index 256560b..52ea927 100644 --- a/train.py +++ b/train.py @@ -225,7 +225,7 @@ def train(hyp, opt, device, tb_writer=None): if rank != -1: indices = torch.zeros([dataset.n], dtype=torch.int) if rank == 0: - indices[:] = torch.from_tensor(dataset.indices, dtype=torch.int) + indices[:] = torch.tensor(dataset.indices, dtype=torch.int) dist.broadcast(indices, 0) if rank != 0: dataset.indices = indices.cpu().numpy()