Browse Source

update common.py Classify()

5.0
Glenn Jocher 4 years ago
parent
commit
a97c3f94ec
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      models/common.py

+ 2
- 1
models/common.py View File

@@ -112,4 +112,5 @@ class Classify(nn.Module):
self.flat = Flatten()

def forward(self, x):
return self.flat(self.conv(self.aap(x))) # flatten to x(b,c2)
z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
return self.flat(self.conv(z)) # flatten to x(b,c2)

Loading…
Cancel
Save