
Remove DDP MultiHeadAttention fix (#3768)

modifyDataloader
Glenn Jocher (GitHub) · 3 years ago
parent
current commit
f2d97ebb25
No known key found for this signature in database. GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 1 addition and 3 deletions
train.py  +1  −3

@@ -252,9 +252,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # DDP mode
     if cuda and RANK != -1:
-        model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK,
-                    # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
-                    find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))
+        model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
 
     # Model parameters
     hyp['box'] *= 3. / nl  # scale to layers
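The removed lines enabled DDP's `find_unused_parameters` only when the model contained an `nn.MultiheadAttention` layer, as a workaround for pytorch/pytorch#26698. A minimal sketch of that detection logic, runnable without a distributed setup (the `needs_unused_param_handling` helper name is illustrative, not from the repo):

```python
import torch.nn as nn

def needs_unused_param_handling(model: nn.Module) -> bool:
    """Return True if any submodule is nn.MultiheadAttention.

    This reproduces the condition the commit removed: the DDP wrapper
    previously set find_unused_parameters to this value as a workaround
    for the nn.MultiheadAttention/DDP incompatibility.
    """
    return any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules())

plain = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
with_attn = nn.Sequential(nn.Linear(8, 8), nn.MultiheadAttention(8, 2))

print(needs_unused_param_handling(plain))      # False
print(needs_unused_param_handling(with_attn))  # True
```

After the change, `train.py` always wraps with plain `DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)`; note that enabling `find_unused_parameters` adds per-iteration graph traversal overhead, so dropping it is the cheaper default when all parameters receive gradients.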
