|
|
@@ -3,10 +3,9 @@ |
|
|
|
import logging |
|
|
|
import os |
|
|
|
import sys |
|
|
|
import yaml |
|
|
|
from contextlib import contextmanager |
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
import yaml |
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
FILE = Path(__file__).absolute() |
|
|
@@ -99,7 +98,7 @@ class WandbLogger(): |
|
|
|
https://docs.wandb.com/guides/integrations/yolov5 |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, opt, run_id, data_dict, job_type='Training'): |
|
|
|
def __init__(self, opt, run_id, job_type='Training'): |
|
|
|
""" |
|
|
|
- Initialize WandbLogger instance |
|
|
|
- Upload dataset if opt.upload_dataset is True |
|
|
@@ -108,7 +107,6 @@ class WandbLogger(): |
|
|
|
arguments: |
|
|
|
opt (namespace) -- Commandline arguments for this run |
|
|
|
run_id (str) -- Run ID of W&B run to be resumed |
|
|
|
data_dict (Dict) -- Dictionary conataining info about the dataset to be used |
|
|
|
job_type (str) -- To set the job_type for this run |
|
|
|
|
|
|
|
""" |
|
|
@@ -119,10 +117,11 @@ class WandbLogger(): |
|
|
|
self.train_artifact_path, self.val_artifact_path = None, None |
|
|
|
self.result_artifact = None |
|
|
|
self.val_table, self.result_table = None, None |
|
|
|
self.data_dict = data_dict |
|
|
|
self.bbox_media_panel_images = [] |
|
|
|
self.val_table_path_map = None |
|
|
|
self.max_imgs_to_log = 16 |
|
|
|
self.wandb_artifact_data_dict = None |
|
|
|
self.data_dict = None |
|
|
|
# It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call |
|
|
|
if isinstance(opt.resume, str): # checks resume from artifact |
|
|
|
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): |
|
|
@@ -148,11 +147,23 @@ class WandbLogger(): |
|
|
|
if self.wandb_run: |
|
|
|
if self.job_type == 'Training': |
|
|
|
if not opt.resume: |
|
|
|
wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict |
|
|
|
# Info useful for resuming from artifacts |
|
|
|
self.wandb_run.config.update({'opt': vars(opt), 'data_dict': wandb_data_dict}, |
|
|
|
allow_val_change=True) |
|
|
|
self.data_dict = self.setup_training(opt, data_dict) |
|
|
|
if opt.upload_dataset: |
|
|
|
self.wandb_artifact_data_dict = self.check_and_upload_dataset(opt) |
|
|
|
|
|
|
|
elif opt.data.endswith('_wandb.yaml'): # When dataset is W&B artifact |
|
|
|
with open(opt.data, encoding='ascii', errors='ignore') as f: |
|
|
|
data_dict = yaml.safe_load(f) |
|
|
|
self.data_dict = data_dict |
|
|
|
else: # Local .yaml dataset file or .zip file |
|
|
|
self.data_dict = check_dataset(opt.data) |
|
|
|
|
|
|
|
self.setup_training(opt) |
|
|
|
# write data_dict to config. useful for resuming from artifacts |
|
|
|
if not self.wandb_artifact_data_dict: |
|
|
|
self.wandb_artifact_data_dict = self.data_dict |
|
|
|
self.wandb_run.config.update({'data_dict': self.wandb_artifact_data_dict}, |
|
|
|
allow_val_change=True) |
|
|
|
|
|
|
|
if self.job_type == 'Dataset Creation': |
|
|
|
self.data_dict = self.check_and_upload_dataset(opt) |
|
|
|
|
|
|
@@ -167,7 +178,7 @@ class WandbLogger(): |
|
|
|
Updated dataset info dictionary where local dataset paths are replaced by WAND_ARFACT_PREFIX links. |
|
|
|
""" |
|
|
|
assert wandb, 'Install wandb to upload dataset' |
|
|
|
config_path = self.log_dataset_artifact(check_file(opt.data), |
|
|
|
config_path = self.log_dataset_artifact(opt.data, |
|
|
|
opt.single_cls, |
|
|
|
'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem) |
|
|
|
print("Created dataset config file ", config_path) |
|
|
@@ -175,7 +186,7 @@ class WandbLogger(): |
|
|
|
wandb_data_dict = yaml.safe_load(f) |
|
|
|
return wandb_data_dict |
|
|
|
|
|
|
|
def setup_training(self, opt, data_dict): |
|
|
|
def setup_training(self, opt): |
|
|
|
""" |
|
|
|
Setup the necessary processes for training YOLO models: |
|
|
|
- Attempt to download model checkpoint and dataset artifacts if opt.resume stats with WANDB_ARTIFACT_PREFIX |
|
|
@@ -184,10 +195,7 @@ class WandbLogger(): |
|
|
|
|
|
|
|
arguments: |
|
|
|
opt (namespace) -- commandline arguments for this run |
|
|
|
data_dict (Dict) -- Dataset dictionary for this run |
|
|
|
|
|
|
|
returns: |
|
|
|
data_dict (Dict) -- contains the updated info about the dataset to be used for training |
|
|
|
""" |
|
|
|
self.log_dict, self.current_epoch = {}, 0 |
|
|
|
self.bbox_interval = opt.bbox_interval |
|
|
@@ -198,8 +206,10 @@ class WandbLogger(): |
|
|
|
config = self.wandb_run.config |
|
|
|
opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str( |
|
|
|
self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \ |
|
|
|
config.opt['hyp'] |
|
|
|
config.hyp |
|
|
|
data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume |
|
|
|
else: |
|
|
|
data_dict = self.data_dict |
|
|
|
if self.val_artifact is None: # If --upload_dataset is set, use the existing artifact, don't download |
|
|
|
self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'), |
|
|
|
opt.artifact_alias) |
|
|
@@ -221,7 +231,10 @@ class WandbLogger(): |
|
|
|
self.map_val_table_path() |
|
|
|
if opt.bbox_interval == -1: |
|
|
|
self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1 |
|
|
|
return data_dict |
|
|
|
train_from_artifact = self.train_artifact_path is not None and self.val_artifact_path is not None |
|
|
|
# Update the the data_dict to point to local artifacts dir |
|
|
|
if train_from_artifact: |
|
|
|
self.data_dict = data_dict |
|
|
|
|
|
|
|
def download_dataset_artifact(self, path, alias): |
|
|
|
""" |
|
|
@@ -299,7 +312,8 @@ class WandbLogger(): |
|
|
|
returns: |
|
|
|
the new .yaml file with artifact links. it can be used to start training directly from artifacts |
|
|
|
""" |
|
|
|
data = check_dataset(data_file) # parse and check |
|
|
|
self.data_dict = check_dataset(data_file) # parse and check |
|
|
|
data = dict(self.data_dict) |
|
|
|
nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names']) |
|
|
|
names = {k: v for k, v in enumerate(names)} # to index dictionary |
|
|
|
self.train_artifact = self.create_dataset_table(LoadImagesAndLabels( |
|
|
@@ -310,7 +324,8 @@ class WandbLogger(): |
|
|
|
data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train') |
|
|
|
if data.get('val'): |
|
|
|
data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val') |
|
|
|
path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path |
|
|
|
path = Path(data_file).stem |
|
|
|
path = (path if overwrite_config else path + '_wandb') + '.yaml' # updated data.yaml path |
|
|
|
data.pop('download', None) |
|
|
|
data.pop('path', None) |
|
|
|
with open(path, 'w') as f: |