Browse Source

Implement DDP `static_graph=True` (#6940)

* Implement DDP `static_graph=True`

Experimental implementation of new PyTorch 1.11.0 DDP feature.

* Add 1.11.0 check

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
modifyDataloader
Glenn Jocher GitHub 2 years ago
parent
commit
d95a728f55
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 4 deletions
  1. +7
    -4
      train.py

+ 7
- 4
train.py View File

from utils.datasets import create_dataloader from utils.datasets import create_dataloader
from utils.downloads import attempt_download from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, methods,
one_cycle, print_args, print_mutation, strip_optimizer)
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights,
methods, one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss from utils.loss import ComputeLoss


# DDP mode # DDP mode
if cuda and RANK != -1: if cuda and RANK != -1:
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
if check_version(torch.__version__, '1.11.0'):
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
else:
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)


# Model attributes # Model attributes
nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps) nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)

Loading…
Cancel
Save