@@ -125,7 +125,7 @@ class ComputeLoss: | |||
# Losses | |||
for i, pi in enumerate(p): # layer index, layer predictions | |||
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx | |||
tobj = torch.zeros(pi.shape[:4], device=self.device) # target obj | |||
tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj | |||
n = b.shape[0] # number of targets | |||
if n: |