from models.experimental import attempt_load | from models.experimental import attempt_load | ||||
from models.yolo import Model | from models.yolo import Model | ||||
from utils.downloads import attempt_download | from utils.downloads import attempt_download | ||||
from utils.general import check_requirements, set_logging | |||||
from utils.general import check_requirements, intersect_dicts, set_logging | |||||
from utils.torch_utils import select_device | from utils.torch_utils import select_device | ||||
file = Path(__file__).resolve() | file = Path(__file__).resolve() | ||||
model = Model(cfg, channels, classes) # create model | model = Model(cfg, channels, classes) # create model | ||||
if pretrained: | if pretrained: | ||||
ckpt = torch.load(attempt_download(path), map_location=device) # load | ckpt = torch.load(attempt_download(path), map_location=device) # load | ||||
msd = model.state_dict() # model state_dict | |||||
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 | csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 | ||||
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter | |||||
csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors']) # intersect | |||||
model.load_state_dict(csd, strict=False) # load | model.load_state_dict(csd, strict=False) # load | ||||
if len(ckpt['model'].names) == classes: | if len(ckpt['model'].names) == classes: | ||||
model.names = ckpt['model'].names # set class names attribute | model.names = ckpt['model'].names # set class names attribute |
from utils.downloads import attempt_download | from utils.downloads import attempt_download | ||||
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, | from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, | ||||
check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, | check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, | ||||
labels_to_class_weights, labels_to_image_weights, methods, one_cycle, print_args, | |||||
print_mutation, strip_optimizer) | |||||
intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle, | |||||
print_args, print_mutation, strip_optimizer) | |||||
from utils.loggers import Loggers | from utils.loggers import Loggers | ||||
from utils.loggers.wandb.wandb_utils import check_wandb_resume | from utils.loggers.wandb.wandb_utils import check_wandb_resume | ||||
from utils.loss import ComputeLoss | from utils.loss import ComputeLoss | ||||
from utils.metrics import fitness | from utils.metrics import fitness | ||||
from utils.plots import plot_evolve, plot_labels | from utils.plots import plot_evolve, plot_labels | ||||
from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, | |||||
torch_distributed_zero_first) | |||||
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first | |||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html | LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html | ||||
RANK = int(os.getenv('RANK', -1)) | RANK = int(os.getenv('RANK', -1)) |
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False) | cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False) | ||||
def intersect_dicts(da, db, exclude=()): | |||||
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values | |||||
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape} | |||||
def get_latest_run(search_dir='.'): | def get_latest_run(search_dir='.'): | ||||
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from) | # Return path to most recent 'last.pt' in /runs (i.e. to --resume from) | ||||
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) | last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) |
return model.module if is_parallel(model) else model | return model.module if is_parallel(model) else model | ||||
def intersect_dicts(da, db, exclude=()): | |||||
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values | |||||
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape} | |||||
def initialize_weights(model): | def initialize_weights(model): | ||||
for m in model.modules(): | for m in model.modules(): | ||||
t = type(m) | t = type(m) |