瀏覽代碼

Update train.py forward simplification

5.0
Glenn Jocher 4 年之前
父節點
當前提交
a21bd0687c
共有 1 個文件被更改,包括 3 次插入9 次删除
  1. +3
    -9
      train.py

+ 3
- 9
train.py 查看文件

@@ -265,18 +265,12 @@ def train(hyp, opt, device, tb_writer=None):
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

# Autocast
# Forward
with amp.autocast(enabled=cuda):
# Forward
pred = model(imgs)

# Loss
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
# if not torch.isfinite(loss):
# logger.info('WARNING: non-finite loss, ending training ', loss_items)
# return results

# Backward
scaler.scale(loss).backward()

Loading…
取消
儲存