`Ensemble()` visualize fix (#3973)
* fix visualize error * Revert "fix visualize error" * add visualise profile
This commit is contained in:
parent
a544d59f52
commit
647223a7a8
|
|
@ -100,10 +100,10 @@ class Ensemble(nn.ModuleList):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Ensemble, self).__init__()
|
super(Ensemble, self).__init__()
|
||||||
|
|
||||||
def forward(self, x, augment=False):
|
def forward(self, x, augment=False, profile=False, visualize=False):
|
||||||
y = []
|
y = []
|
||||||
for module in self:
|
for module in self:
|
||||||
y.append(module(x, augment)[0])
|
y.append(module(x, augment, profile, visualize)[0])
|
||||||
# y = torch.stack(y).max(0)[0] # max ensemble
|
# y = torch.stack(y).max(0)[0] # max ensemble
|
||||||
# y = torch.stack(y).mean(0) # mean ensemble
|
# y = torch.stack(y).mean(0) # mean ensemble
|
||||||
y = torch.cat(y, 1) # nms ensemble
|
y = torch.cat(y, 1) # nms ensemble
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue