Browse Source

opt.img_weights bug fix (#885)

5.0
Glenn Jocher 4 years ago
parent
commit
69ff781ca5
1 changed files with 8 additions and 10 deletions
  1. +8
    -10
      train.py

+ 8
- 10
train.py View File

@@ -216,18 +216,15 @@ def train(hyp, opt, device, tb_writer=None):
model.train()

# Update image weights (optional)
if dataset.image_weights:
if opt.img_weights:
# Generate indices
if rank in [-1, 0]:
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
dataset.indices = random.choices(range(dataset.n), weights=image_weights,
k=dataset.n) # rand weighted idx
cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
# Broadcast if DDP
if rank != -1:
indices = torch.zeros([dataset.n], dtype=torch.int)
if rank == 0:
indices[:] = torch.tensor(dataset.indices, dtype=torch.int)
indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
dist.broadcast(indices, 0)
if rank != 0:
dataset.indices = indices.cpu().numpy()
@@ -388,7 +385,8 @@ if __name__ == '__main__':
parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml')
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
parser.add_argument('--img-weights', action='store_true', help='use weighted image selection for training')
parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
@@ -471,7 +469,7 @@ if __name__ == '__main__':
'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
'iou_t': (0, 0.1, 0.7), # IoU training threshold
'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
# 'anchors': (1, 2.0, 10.0), # anchors per output grid (0 to ignore)
# 'anchors': (1, 2.0, 10.0), # anchors per output grid (0 to ignore)
'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)

Loading…
Cancel
Save