Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

146 lines
6.7KB

  1. import json
  2. import shutil
  3. import sys
  4. from datetime import datetime
  5. from pathlib import Path
  6. import torch
  7. sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
  8. from utils.general import colorstr, xywh2xyxy
  9. try:
  10. import wandb
  11. except ImportError:
  12. wandb = None
  13. print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
# Prefix marking a dataset path as a W&B artifact reference instead of a local path;
# download_dataset_artifact() checks for it and strips it before calling wandb.use_artifact().
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
  15. def remove_prefix(from_string, prefix):
  16. return from_string[len(prefix):]
  17. class WandbLogger():
  18. def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
  19. self.wandb = wandb
  20. self.wandb_run = wandb.init(config=opt, resume="allow",
  21. project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
  22. name=name,
  23. job_type=job_type,
  24. id=run_id) if self.wandb else None
  25. if job_type == 'Training':
  26. self.setup_training(opt, data_dict)
  27. if opt.bbox_interval == -1:
  28. opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else opt.epochs
  29. if opt.save_period == -1:
  30. opt.save_period = (opt.epochs // 10) if opt.epochs > 10 else opt.epochs
  31. def setup_training(self, opt, data_dict):
  32. self.log_dict = {}
  33. self.train_artifact_path, self.trainset_artifact = \
  34. self.download_dataset_artifact(data_dict['train'], opt.artifact_alias)
  35. self.test_artifact_path, self.testset_artifact = \
  36. self.download_dataset_artifact(data_dict['val'], opt.artifact_alias)
  37. self.result_artifact, self.result_table, self.weights = None, None, None
  38. if self.train_artifact_path is not None:
  39. train_path = Path(self.train_artifact_path) / 'data/images/'
  40. data_dict['train'] = str(train_path)
  41. if self.test_artifact_path is not None:
  42. test_path = Path(self.test_artifact_path) / 'data/images/'
  43. data_dict['val'] = str(test_path)
  44. self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
  45. self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
  46. if opt.resume_from_artifact:
  47. modeldir, _ = self.download_model_artifact(opt.resume_from_artifact)
  48. if modeldir:
  49. self.weights = Path(modeldir) / "best.pt"
  50. opt.weights = self.weights
  51. def download_dataset_artifact(self, path, alias):
  52. if path.startswith(WANDB_ARTIFACT_PREFIX):
  53. dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
  54. assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
  55. datadir = dataset_artifact.download()
  56. labels_zip = Path(datadir) / "data/labels.zip"
  57. shutil.unpack_archive(labels_zip, Path(datadir) / 'data/labels', 'zip')
  58. print("Downloaded dataset to : ", datadir)
  59. return datadir, dataset_artifact
  60. return None, None
  61. def download_model_artifact(self, name):
  62. model_artifact = wandb.use_artifact(name + ":latest")
  63. assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
  64. modeldir = model_artifact.download()
  65. print("Downloaded model to : ", modeldir)
  66. return modeldir, model_artifact
  67. def log_model(self, path, opt, epoch):
  68. datetime_suffix = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
  69. model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
  70. 'original_url': str(path),
  71. 'epoch': epoch + 1,
  72. 'save period': opt.save_period,
  73. 'project': opt.project,
  74. 'datetime': datetime_suffix
  75. })
  76. model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
  77. model_artifact.add_file(str(path / 'best.pt'), name='best.pt')
  78. wandb.log_artifact(model_artifact)
  79. print("Saving model artifact on epoch ", epoch + 1)
  80. def log_dataset_artifact(self, dataset, class_to_id, name='dataset'):
  81. artifact = wandb.Artifact(name=name, type="dataset")
  82. image_path = dataset.path
  83. artifact.add_dir(image_path, name='data/images')
  84. table = wandb.Table(columns=["id", "train_image", "Classes"])
  85. class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
  86. for si, (img, labels, paths, shapes) in enumerate(dataset):
  87. height, width = shapes[0]
  88. labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4)))
  89. labels[:, 2:] *= torch.Tensor([width, height, width, height])
  90. box_data = []
  91. img_classes = {}
  92. for cls, *xyxy in labels[:, 1:].tolist():
  93. cls = int(cls)
  94. box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
  95. "class_id": cls,
  96. "box_caption": "%s" % (class_to_id[cls]),
  97. "scores": {"acc": 1},
  98. "domain": "pixel"})
  99. img_classes[cls] = class_to_id[cls]
  100. boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space
  101. table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes))
  102. artifact.add(table, name)
  103. labels_path = 'labels'.join(image_path.rsplit('images', 1))
  104. zip_path = Path(labels_path).parent / (name + '_labels.zip')
  105. if not zip_path.is_file(): # make_archive won't check if file exists
  106. shutil.make_archive(zip_path.with_suffix(''), 'zip', labels_path)
  107. artifact.add_file(str(zip_path), name='data/labels.zip')
  108. wandb.log_artifact(artifact)
  109. print("Saving data to W&B...")
  110. def log(self, log_dict):
  111. if self.wandb_run:
  112. for key, value in log_dict.items():
  113. self.log_dict[key] = value
  114. def end_epoch(self):
  115. if self.wandb_run and self.log_dict:
  116. wandb.log(self.log_dict)
  117. self.log_dict = {}
  118. def finish_run(self):
  119. if self.wandb_run:
  120. if self.result_artifact:
  121. print("Add Training Progress Artifact")
  122. self.result_artifact.add(self.result_table, 'result')
  123. train_results = wandb.JoinedTable(self.testset_artifact.get("val"), self.result_table, "id")
  124. self.result_artifact.add(train_results, 'joined_result')
  125. wandb.log_artifact(self.result_artifact)
  126. if self.log_dict:
  127. wandb.log(self.log_dict)
  128. wandb.run.finish()