Selaa lähdekoodia

update common.py add Classify()

5.0
Glenn Jocher 4 vuotta sitten
vanhempi
commit
5387d4747d
1 muutettua tiedostoa jossa 19 lisäystä ja 6 poistoa
  1. +19
    -6
      models/common.py

+ 19
- 6
models/common.py Näytä tiedosto

@@ -76,12 +76,6 @@ class SPP(nn.Module):
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))


class Flatten(nn.Module):
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
def forward(self, x):
return x.view(x.size(0), -1)


class Focus(nn.Module):
# Focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
@@ -100,3 +94,22 @@ class Concat(nn.Module):

def forward(self, x):
return torch.cat(x, self.d)


class Flatten(nn.Module):
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
@staticmethod
def forward(x):
return x.view(x.size(0), -1)


class Classify(nn.Module):
# Classification head, i.e. x(b,c1,20,20) to x(b,c2)
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
super(Classify, self).__init__()
self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) # to x(b,c2,1,1)
self.flat = Flatten()

def forward(self, x):
return self.flat(self.conv(self.aap(x))) # flatten to x(b,c2)

Loading…
Peruuta
Tallenna