|
|
@@ -37,6 +37,7 @@ except ImportError: |
|
|
|
class Detect(nn.Module): |
|
|
|
stride = None # strides computed during build |
|
|
|
onnx_dynamic = False # ONNX export parameter |
|
|
|
export = False # export mode |
|
|
|
|
|
|
|
def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer |
|
|
|
super().__init__() |
|
|
@@ -72,7 +73,7 @@ class Detect(nn.Module): |
|
|
|
y = torch.cat((xy, wh, conf), 4) |
|
|
|
z.append(y.view(bs, -1, self.no)) |
|
|
|
|
|
|
|
return x if self.training else (torch.cat(z, 1), x) |
|
|
|
return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x) |
|
|
|
|
|
|
|
def _make_grid(self, nx=20, ny=20, i=0): |
|
|
|
d = self.anchors[i].device |