|
|
@@ -120,7 +120,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary |
|
|
|
if pretrained: |
|
|
|
with torch_distributed_zero_first(LOCAL_RANK): |
|
|
|
weights = attempt_download(weights) # download if not found locally |
|
|
|
-ckpt = torch.load(weights, map_location=device)  # load checkpoint |
|
|
|
+ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak |
|
|
|
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create |
|
|
|
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys |
|
|
|
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 |