|
|
@@ -75,10 +75,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): |
|
|
|
with torch_distributed_zero_first(rank): |
|
|
|
attempt_download(weights) # download if not found locally |
|
|
|
ckpt = torch.load(weights, map_location=device) # load checkpoint |
|
|
|
if hyp.get('anchors'): |
|
|
|
ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor |
|
|
|
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create |
|
|
|
exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [] # exclude keys |
|
|
|
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create |
|
|
|
exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys |
|
|
|
state_dict = ckpt['model'].float().state_dict() # to FP32 |
|
|
|
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect |
|
|
|
model.load_state_dict(state_dict, strict=False) # load |
|
|
@@ -216,6 +214,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): |
|
|
|
# Anchors |
|
|
|
if not opt.noautoanchor: |
|
|
|
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) |
|
|
|
model.half().float() # pre-reduce anchor precision |
|
|
|
|
|
|
|
# Model parameters |
|
|
|
hyp['box'] *= 3. / nl # scale to layers |