Browse Source

offset and balance update

5.0
Glenn Jocher 4 years ago
parent
commit
f767023c56
1 changed file with 10 additions and 9 deletions
  1. +10
    -9
      utils/utils.py

+ 10
- 9
utils/utils.py View File



# per output # per output
nt = 0 # targets nt = 0 # targets
balance = [1.0, 1.0, 1.0]
for i, pi in enumerate(p): # layer index, layer predictions for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tobj = torch.zeros_like(pi[..., 0]) # target obj tobj = torch.zeros_like(pi[..., 0]) # target obj
# with open('targets.txt', 'a') as file: # with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]


lobj += BCEobj(pi[..., 4], tobj) # obj loss
lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss


lbox *= h['giou']
lobj *= h['obj']
lcls *= h['cls']
s = 3 / (i + 1) # output count scaling
lbox *= h['giou'] * s
lobj *= h['obj'] * s
lcls *= h['cls'] * s
bs = tobj.shape[0] # batch size bs = tobj.shape[0] # batch size
if red == 'sum': if red == 'sum':
g = 3.0 # loss gain g = 3.0 # loss gain
a, t = at[j], t.repeat(na, 1, 1)[j] # filter a, t = at[j], t.repeat(na, 1, 1)[j] # filter


# overlaps # overlaps
g = 0.5 # offset
gxy = t[:, 2:4] # grid xy gxy = t[:, 2:4] # grid xy
z = torch.zeros_like(gxy) z = torch.zeros_like(gxy)
if style == 'rect2': if style == 'rect2':
g = 0.2 # offset
j, k = ((gxy % 1. < g) & (gxy > 1.)).T j, k = ((gxy % 1. < g) & (gxy > 1.)).T
a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0) a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g


elif style == 'rect4': elif style == 'rect4':
g = 0.5 # offset
j, k = ((gxy % 1. < g) & (gxy > 1.)).T j, k = ((gxy % 1. < g) & (gxy > 1.)).T
l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0) a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0)
wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh


# Filter # Filter
i = (wh0 < 4.0).any(1).sum()
i = (wh0 < 3.0).any(1).sum()
if i: if i:
print('WARNING: Extremely small objects found. ' print('WARNING: Extremely small objects found. '
'%g of %g labels are < 4 pixels in width or height.' % (i, len(wh0)))
wh = wh0[(wh0 >= 4.0).any(1)] # filter > 2 pixels
'%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels


# Kmeans calculation # Kmeans calculation
from scipy.cluster.vq import kmeans from scipy.cluster.vq import kmeans

Loading…
Cancel
Save