Browse Source

`intersect_dicts()` in hubconf.py fix (#5542)

modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
e189fa15ea
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 10 additions and 12 deletions
  1. +2
    -3
      hubconf.py
  2. +3
    -4
      train.py
  3. +5
    -0
      utils/general.py
  4. +0
    -5
      utils/torch_utils.py

+ 2
- 3
hubconf.py View File

from models.experimental import attempt_load from models.experimental import attempt_load
from models.yolo import Model from models.yolo import Model
from utils.downloads import attempt_download from utils.downloads import attempt_download
from utils.general import check_requirements, set_logging
from utils.general import check_requirements, intersect_dicts, set_logging
from utils.torch_utils import select_device from utils.torch_utils import select_device


file = Path(__file__).resolve() file = Path(__file__).resolve()
model = Model(cfg, channels, classes) # create model model = Model(cfg, channels, classes) # create model
if pretrained: if pretrained:
ckpt = torch.load(attempt_download(path), map_location=device) # load ckpt = torch.load(attempt_download(path), map_location=device) # load
msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors']) # intersect
model.load_state_dict(csd, strict=False) # load model.load_state_dict(csd, strict=False) # load
if len(ckpt['model'].names) == classes: if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute model.names = ckpt['model'].names # set class names attribute

+ 3
- 4
train.py View File

from utils.downloads import attempt_download from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
labels_to_class_weights, labels_to_image_weights, methods, one_cycle, print_args,
print_mutation, strip_optimizer)
intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss from utils.loss import ComputeLoss
from utils.metrics import fitness from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels from utils.plots import plot_evolve, plot_labels
from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device,
torch_distributed_zero_first)
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first


LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1)) RANK = int(os.getenv('RANK', -1))

+ 5
- 0
utils/general.py View File

cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False) cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)




def intersect_dicts(da, db, exclude=()):
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}


def get_latest_run(search_dir='.'): def get_latest_run(search_dir='.'):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from) # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)

+ 0
- 5
utils/torch_utils.py View File

return model.module if is_parallel(model) else model return model.module if is_parallel(model) else model




def intersect_dicts(da, db, exclude=()):
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}


def initialize_weights(model): def initialize_weights(model):
for m in model.modules(): for m in model.modules():
t = type(m) t = type(m)

Loading…
Cancel
Save