Browse Source

Resume with custom anchors fix (#2361)

* Resume with custom anchors fix

* Update train.py
5.0
Glenn Jocher GitHub 3 years ago
parent
commit
e931b9da33
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 4 deletions
  1. +3
    -4
      train.py

+ 3
- 4
train.py View File

with torch_distributed_zero_first(rank): with torch_distributed_zero_first(rank):
attempt_download(weights) # download if not found locally attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint ckpt = torch.load(weights, map_location=device) # load checkpoint
if hyp.get('anchors'):
ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [] # exclude keys
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
state_dict = ckpt['model'].float().state_dict() # to FP32 state_dict = ckpt['model'].float().state_dict() # to FP32
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
model.load_state_dict(state_dict, strict=False) # load model.load_state_dict(state_dict, strict=False) # load
# Anchors # Anchors
if not opt.noautoanchor: if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
model.half().float() # pre-reduce anchor precision


# Model parameters # Model parameters
hyp['box'] *= 3. / nl # scale to layers hyp['box'] *= 3. / nl # scale to layers

Loading…
Cancel
Save