Browse Source

Merge remote-tracking branch 'origin/master'

5.0
Glenn Jocher 4 years ago
parent
commit
5e2429e618
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      utils/utils.py

+ 3
- 3
utils/utils.py View File

BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)


# per output # per output
nt = 0 # targets
nt = 0 # number of targets
np = len(p) # number of outputs
balance = [1.0, 1.0, 1.0] balance = [1.0, 1.0, 1.0]
for i, pi in enumerate(p): # layer index, layer predictions for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx b, a, gj, gi = indices[i] # image, anchor, gridy, gridx


lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss


s = 3 / (i + 1) # output count scaling
s = 3 / np # output count scaling
lbox *= h['giou'] * s lbox *= h['giou'] * s
lobj *= h['obj'] * s lobj *= h['obj'] * s
lcls *= h['cls'] * s lcls *= h['cls'] * s
j, k = ((gxy % 1. < g) & (gxy > 1.)).T j, k = ((gxy % 1. < g) & (gxy > 1.)).T
a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0) a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g

elif style == 'rect4': elif style == 'rect4':
j, k = ((gxy % 1. < g) & (gxy > 1.)).T j, k = ((gxy % 1. < g) & (gxy > 1.)).T
l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T

Loading…
Cancel
Save