torch.from_tensor() bug fix
This commit is contained in:
parent
4fb8cb353f
commit
09402a2174
2
train.py
2
train.py
|
|
@ -225,7 +225,7 @@ def train(hyp, opt, device, tb_writer=None):
|
||||||
if rank != -1:
|
if rank != -1:
|
||||||
indices = torch.zeros([dataset.n], dtype=torch.int)
|
indices = torch.zeros([dataset.n], dtype=torch.int)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
indices[:] = torch.from_tensor(dataset.indices, dtype=torch.int)
|
indices[:] = torch.tensor(dataset.indices, dtype=torch.int)
|
||||||
dist.broadcast(indices, 0)
|
dist.broadcast(indices, 0)
|
||||||
if rank != 0:
|
if rank != 0:
|
||||||
dataset.indices = indices.cpu().numpy()
|
dataset.indices = indices.cpu().numpy()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue