|
|
|
|
|
|
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
super(Ensemble, self).__init__() |
|
|
super(Ensemble, self).__init__() |
|
|
|
|
|
|
|
|
def forward(self, x, augment=False): |
|
|
|
|
|
|
|
|
def forward(self, x, augment=False, profile=False, visualize=False): |
|
|
y = [] |
|
|
y = [] |
|
|
for module in self: |
|
|
for module in self: |
|
|
y.append(module(x, augment)[0]) |
|
|
|
|
|
|
|
|
y.append(module(x, augment, profile, visualize)[0]) |
|
|
# y = torch.stack(y).max(0)[0] # max ensemble |
|
|
# y = torch.stack(y).max(0)[0] # max ensemble |
|
|
# y = torch.stack(y).mean(0) # mean ensemble |
|
|
# y = torch.stack(y).mean(0) # mean ensemble |
|
|
y = torch.cat(y, 1) # nms ensemble |
|
|
y = torch.cat(y, 1) # nms ensemble |