|
|
@@ -72,9 +72,9 @@ class Detect(nn.Module): |
|
|
|
def _make_grid(self, nx=20, ny=20, i=0): |
|
|
|
d = self.anchors[i].device |
|
|
|
if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility |
|
|
|
yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)], indexing='ij') |
|
|
|
yv, xv = torch.meshgrid([torch.arange(ny, device=d), torch.arange(nx, device=d)], indexing='ij') |
|
|
|
else: |
|
|
|
yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)]) |
|
|
|
yv, xv = torch.meshgrid([torch.arange(ny, device=d), torch.arange(nx, device=d)]) |
|
|
|
grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float() |
|
|
|
anchor_grid = (self.anchors[i].clone() * self.stride[i]) \ |
|
|
|
.view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float() |