|
|
@@ -47,9 +47,9 @@ from utils.callbacks import Callbacks |
|
|
|
from utils.datasets import create_dataloader |
|
|
|
from utils.downloads import attempt_download |
|
|
|
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, |
|
|
|
check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, |
|
|
|
intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, methods, |
|
|
|
one_cycle, print_args, print_mutation, strip_optimizer) |
|
|
|
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, |
|
|
|
init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, |
|
|
|
methods, one_cycle, print_args, print_mutation, strip_optimizer) |
|
|
|
from utils.loggers import Loggers |
|
|
|
from utils.loggers.wandb.wandb_utils import check_wandb_resume |
|
|
|
from utils.loss import ComputeLoss |
|
|
@@ -269,7 +269,10 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio |
|
|
|
|
|
|
|
# DDP mode |
|
|
|
if cuda and RANK != -1: |
|
|
|
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) |
|
|
|
if check_version(torch.__version__, '1.11.0'): |
|
|
|
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True) |
|
|
|
else: |
|
|
|
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) |
|
|
|
|
|
|
|
# Model attributes |
|
|
|
nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps) |