Browse Source

W&B: Update Tables API and comply with new dataset_check (#3772)

* Update tables API and windows path fix

* update dataset check
modifyDataloader
Ayush Chaurasia GitHub 3 years ago
parent
commit
ffb6e11050
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 6 deletions
  1. +11
    -6
      utils/wandb_logging/wandb_utils.py

+ 11
- 6
utils/wandb_logging/wandb_utils.py View File

@@ -136,7 +136,6 @@ class WandbLogger():

def check_and_upload_dataset(self, opt):
assert wandb, 'Install wandb to upload dataset'
check_dataset(self.data_dict)
config_path = self.log_dataset_artifact(check_file(opt.data),
opt.single_cls,
'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
@@ -171,9 +170,11 @@ class WandbLogger():
data_dict['val'] = str(val_path)
self.val_table = self.val_artifact.get("val")
self.map_val_table_path()
wandb.log({"validation dataset": self.val_table})
if self.val_artifact is not None:
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
if opt.bbox_interval == -1:
self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
return data_dict
@@ -181,7 +182,7 @@ class WandbLogger():
def download_dataset_artifact(self, path, alias):
if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
dataset_artifact = wandb.use_artifact(artifact_path.as_posix())
dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\","/"))
assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
datadir = dataset_artifact.download()
return datadir, dataset_artifact
@@ -216,6 +217,7 @@ class WandbLogger():
def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
with open(data_file) as f:
data = yaml.safe_load(f) # data dict
check_dataset(data)
nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
names = {k: v for k, v in enumerate(names)} # to index dictionary
self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
@@ -228,6 +230,7 @@ class WandbLogger():
data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path
data.pop('download', None)
data.pop('path', None)
with open(path, 'w') as f:
yaml.safe_dump(data, f)

@@ -297,6 +300,7 @@ class WandbLogger():
id = self.val_table_map[Path(path).name]
self.result_table.add_data(self.current_epoch,
id,
self.val_table.data[id][1],
wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
total_conf / max(1, len(box_data))
)
@@ -312,11 +316,12 @@ class WandbLogger():
wandb.log(self.log_dict)
self.log_dict = {}
if self.result_artifact:
train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
self.result_artifact.add(train_results, 'result')
self.result_artifact.add(self.result_table, 'result')
wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
('best' if best_result else '')])
self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
wandb.log({"evaluation": self.result_table})
self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")

def finish_run(self):

Loading…
Cancel
Save