|
|
@@ -108,13 +108,15 @@ class ComputeLoss: |
|
|
|
if g > 0: |
|
|
|
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) |
|
|
|
|
|
|
|
det = de_parallel(model).model[-1] # Detect() module |
|
|
|
self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7 |
|
|
|
self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index |
|
|
|
m = de_parallel(model).model[-1] # Detect() module |
|
|
|
self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7 |
|
|
|
self.ssi = list(m.stride).index(16) if autobalance else 0 # stride 16 index |
|
|
|
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance |
|
|
|
self.na = m.na # number of anchors |
|
|
|
self.nc = m.nc # number of classes |
|
|
|
self.nl = m.nl # number of layers |
|
|
|
self.anchors = m.anchors |
|
|
|
self.device = device |
|
|
|
for k in 'na', 'nc', 'nl', 'anchors': |
|
|
|
setattr(self, k, getattr(det, k)) |
|
|
|
|
|
|
|
def __call__(self, p, targets): # predictions, targets |
|
|
|
lcls = torch.zeros(1, device=self.device) # class loss |
|
|
@@ -129,7 +131,8 @@ class ComputeLoss: |
|
|
|
|
|
|
|
n = b.shape[0] # number of targets |
|
|
|
if n: |
|
|
|
pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # target-subset of predictions |
|
|
|
# pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # faster, requires torch 1.8.0 |
|
|
|
pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1) # target-subset of predictions |
|
|
|
|
|
|
|
# Regression |
|
|
|
pxy = pxy.sigmoid() * 2 - 0.5 |