@@ -112,4 +112,5 @@ class Classify(nn.Module): | |||
self.flat = Flatten() | |||
def forward(self, x): | |||
return self.flat(self.conv(self.aap(x))) # flatten to x(b,c2) | |||
z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list | |||
return self.flat(self.conv(z)) # flatten to x(b,c2) |