You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

342 lines
17KB

  1. """Utilities and tools for tracking runs with Weights & Biases."""
  2. import logging
  3. import sys
  4. from contextlib import contextmanager
  5. from pathlib import Path
  6. import yaml
  7. from tqdm import tqdm
  8. sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
  9. from utils.datasets import LoadImagesAndLabels
  10. from utils.datasets import img2label_paths
  11. from utils.general import colorstr, check_dataset, check_file
  12. try:
  13. import wandb
  14. from wandb import init, finish
  15. except ImportError:
  16. wandb = None
  17. WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
  18. def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):
  19. return from_string[len(prefix):]
  20. def check_wandb_config_file(data_config_file):
  21. wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path
  22. if Path(wandb_config).is_file():
  23. return wandb_config
  24. return data_config_file
  25. def get_run_info(run_path):
  26. run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
  27. run_id = run_path.stem
  28. project = run_path.parent.stem
  29. entity = run_path.parent.parent.stem
  30. model_artifact_name = 'run_' + run_id + '_model'
  31. return entity, project, run_id, model_artifact_name
  32. def check_wandb_resume(opt):
  33. process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None
  34. if isinstance(opt.resume, str):
  35. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  36. if opt.global_rank not in [-1, 0]: # For resuming DDP runs
  37. entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
  38. api = wandb.Api()
  39. artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
  40. modeldir = artifact.download()
  41. opt.weights = str(Path(modeldir) / "last.pt")
  42. return True
  43. return None
  44. def process_wandb_config_ddp_mode(opt):
  45. with open(check_file(opt.data)) as f:
  46. data_dict = yaml.safe_load(f) # data dict
  47. train_dir, val_dir = None, None
  48. if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):
  49. api = wandb.Api()
  50. train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias)
  51. train_dir = train_artifact.download()
  52. train_path = Path(train_dir) / 'data/images/'
  53. data_dict['train'] = str(train_path)
  54. if isinstance(data_dict['val'], str) and data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX):
  55. api = wandb.Api()
  56. val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias)
  57. val_dir = val_artifact.download()
  58. val_path = Path(val_dir) / 'data/images/'
  59. data_dict['val'] = str(val_path)
  60. if train_dir or val_dir:
  61. ddp_data_path = str(Path(val_dir) / 'wandb_local_data.yaml')
  62. with open(ddp_data_path, 'w') as f:
  63. yaml.safe_dump(data_dict, f)
  64. opt.data = ddp_data_path
  65. class WandbLogger():
  66. """Log training runs, datasets, models, and predictions to Weights & Biases.
  67. This logger sends information to W&B at wandb.ai. By default, this information
  68. includes hyperparameters, system configuration and metrics, model metrics,
  69. and basic data metrics and analyses.
  70. By providing additional command line arguments to train.py, datasets,
  71. models and predictions can also be logged.
  72. For more on how this logger is used, see the Weights & Biases documentation:
  73. https://docs.wandb.com/guides/integrations/yolov5
  74. """
  75. def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
  76. # Pre-training routine --
  77. self.job_type = job_type
  78. self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
  79. # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
  80. if isinstance(opt.resume, str): # checks resume from artifact
  81. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  82. entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
  83. model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
  84. assert wandb, 'install wandb to resume wandb runs'
  85. # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
  86. self.wandb_run = wandb.init(id=run_id,
  87. project=project,
  88. entity=entity,
  89. resume='allow',
  90. allow_val_change=True)
  91. opt.resume = model_artifact_name
  92. elif self.wandb:
  93. self.wandb_run = wandb.init(config=opt,
  94. resume="allow",
  95. project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
  96. entity=opt.entity,
  97. name=name,
  98. job_type=job_type,
  99. id=run_id,
  100. allow_val_change=True) if not wandb.run else wandb.run
  101. if self.wandb_run:
  102. if self.job_type == 'Training':
  103. if not opt.resume:
  104. wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
  105. # Info useful for resuming from artifacts
  106. self.wandb_run.config.opt = vars(opt)
  107. self.wandb_run.config.data_dict = wandb_data_dict
  108. self.data_dict = self.setup_training(opt, data_dict)
  109. if self.job_type == 'Dataset Creation':
  110. self.data_dict = self.check_and_upload_dataset(opt)
  111. else:
  112. prefix = colorstr('wandb: ')
  113. print(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
  114. def check_and_upload_dataset(self, opt):
  115. assert wandb, 'Install wandb to upload dataset'
  116. check_dataset(self.data_dict)
  117. config_path = self.log_dataset_artifact(check_file(opt.data),
  118. opt.single_cls,
  119. 'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
  120. print("Created dataset config file ", config_path)
  121. with open(config_path) as f:
  122. wandb_data_dict = yaml.safe_load(f)
  123. return wandb_data_dict
  124. def setup_training(self, opt, data_dict):
  125. self.log_dict, self.current_epoch, self.log_imgs = {}, 0, 16 # Logging Constants
  126. self.bbox_interval = opt.bbox_interval
  127. if isinstance(opt.resume, str):
  128. modeldir, _ = self.download_model_artifact(opt)
  129. if modeldir:
  130. self.weights = Path(modeldir) / "last.pt"
  131. config = self.wandb_run.config
  132. opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
  133. self.weights), config.save_period, config.total_batch_size, config.bbox_interval, config.epochs, \
  134. config.opt['hyp']
  135. data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume
  136. if 'val_artifact' not in self.__dict__: # If --upload_dataset is set, use the existing artifact, don't download
  137. self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
  138. opt.artifact_alias)
  139. self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
  140. opt.artifact_alias)
  141. self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
  142. if self.train_artifact_path is not None:
  143. train_path = Path(self.train_artifact_path) / 'data/images/'
  144. data_dict['train'] = str(train_path)
  145. if self.val_artifact_path is not None:
  146. val_path = Path(self.val_artifact_path) / 'data/images/'
  147. data_dict['val'] = str(val_path)
  148. self.val_table = self.val_artifact.get("val")
  149. self.map_val_table_path()
  150. if self.val_artifact is not None:
  151. self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
  152. self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
  153. if opt.bbox_interval == -1:
  154. self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
  155. return data_dict
  156. def download_dataset_artifact(self, path, alias):
  157. if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
  158. artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
  159. dataset_artifact = wandb.use_artifact(artifact_path.as_posix())
  160. assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
  161. datadir = dataset_artifact.download()
  162. return datadir, dataset_artifact
  163. return None, None
  164. def download_model_artifact(self, opt):
  165. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  166. model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
  167. assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
  168. modeldir = model_artifact.download()
  169. epochs_trained = model_artifact.metadata.get('epochs_trained')
  170. total_epochs = model_artifact.metadata.get('total_epochs')
  171. is_finished = total_epochs is None
  172. assert not is_finished, 'training is finished, can only resume incomplete runs.'
  173. return modeldir, model_artifact
  174. return None, None
  175. def log_model(self, path, opt, epoch, fitness_score, best_model=False):
  176. model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
  177. 'original_url': str(path),
  178. 'epochs_trained': epoch + 1,
  179. 'save period': opt.save_period,
  180. 'project': opt.project,
  181. 'total_epochs': opt.epochs,
  182. 'fitness_score': fitness_score
  183. })
  184. model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
  185. wandb.log_artifact(model_artifact,
  186. aliases=['latest', 'last', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
  187. print("Saving model artifact on epoch ", epoch + 1)
  188. def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
  189. with open(data_file) as f:
  190. data = yaml.safe_load(f) # data dict
  191. nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
  192. names = {k: v for k, v in enumerate(names)} # to index dictionary
  193. self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
  194. data['train'], rect=True, batch_size=1), names, name='train') if data.get('train') else None
  195. self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
  196. data['val'], rect=True, batch_size=1), names, name='val') if data.get('val') else None
  197. if data.get('train'):
  198. data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
  199. if data.get('val'):
  200. data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
  201. path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path
  202. data.pop('download', None)
  203. with open(path, 'w') as f:
  204. yaml.safe_dump(data, f)
  205. if self.job_type == 'Training': # builds correct artifact pipeline graph
  206. self.wandb_run.use_artifact(self.val_artifact)
  207. self.wandb_run.use_artifact(self.train_artifact)
  208. self.val_artifact.wait()
  209. self.val_table = self.val_artifact.get('val')
  210. self.map_val_table_path()
  211. else:
  212. self.wandb_run.log_artifact(self.train_artifact)
  213. self.wandb_run.log_artifact(self.val_artifact)
  214. return path
  215. def map_val_table_path(self):
  216. self.val_table_map = {}
  217. print("Mapping dataset")
  218. for i, data in enumerate(tqdm(self.val_table.data)):
  219. self.val_table_map[data[3]] = data[0]
  220. def create_dataset_table(self, dataset, class_to_id, name='dataset'):
  221. # TODO: Explore multiprocessing to slpit this loop parallely| This is essential for speeding up the the logging
  222. artifact = wandb.Artifact(name=name, type="dataset")
  223. img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None
  224. img_files = tqdm(dataset.img_files) if not img_files else img_files
  225. for img_file in img_files:
  226. if Path(img_file).is_dir():
  227. artifact.add_dir(img_file, name='data/images')
  228. labels_path = 'labels'.join(dataset.path.rsplit('images', 1))
  229. artifact.add_dir(labels_path, name='data/labels')
  230. else:
  231. artifact.add_file(img_file, name='data/images/' + Path(img_file).name)
  232. label_file = Path(img2label_paths([img_file])[0])
  233. artifact.add_file(str(label_file),
  234. name='data/labels/' + label_file.name) if label_file.exists() else None
  235. table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
  236. class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
  237. for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
  238. box_data, img_classes = [], {}
  239. for cls, *xywh in labels[:, 1:].tolist():
  240. cls = int(cls)
  241. box_data.append({"position": {"middle": [xywh[0], xywh[1]], "width": xywh[2], "height": xywh[3]},
  242. "class_id": cls,
  243. "box_caption": "%s" % (class_to_id[cls])})
  244. img_classes[cls] = class_to_id[cls]
  245. boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space
  246. table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), list(img_classes.values()),
  247. Path(paths).name)
  248. artifact.add(table, name)
  249. return artifact
  250. def log_training_progress(self, predn, path, names):
  251. if self.val_table and self.result_table:
  252. class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
  253. box_data = []
  254. total_conf = 0
  255. for *xyxy, conf, cls in predn.tolist():
  256. if conf >= 0.25:
  257. box_data.append(
  258. {"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
  259. "class_id": int(cls),
  260. "box_caption": "%s %.3f" % (names[cls], conf),
  261. "scores": {"class_score": conf},
  262. "domain": "pixel"})
  263. total_conf = total_conf + conf
  264. boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
  265. id = self.val_table_map[Path(path).name]
  266. self.result_table.add_data(self.current_epoch,
  267. id,
  268. wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
  269. total_conf / max(1, len(box_data))
  270. )
  271. def log(self, log_dict):
  272. if self.wandb_run:
  273. for key, value in log_dict.items():
  274. self.log_dict[key] = value
  275. def end_epoch(self, best_result=False):
  276. if self.wandb_run:
  277. with all_logging_disabled():
  278. wandb.log(self.log_dict)
  279. self.log_dict = {}
  280. if self.result_artifact:
  281. train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
  282. self.result_artifact.add(train_results, 'result')
  283. wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
  284. ('best' if best_result else '')])
  285. self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
  286. self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
  287. def finish_run(self):
  288. if self.wandb_run:
  289. if self.log_dict:
  290. with all_logging_disabled():
  291. wandb.log(self.log_dict)
  292. wandb.run.finish()
  293. @contextmanager
  294. def all_logging_disabled(highest_level=logging.CRITICAL):
  295. """ source - https://gist.github.com/simon-weber/7853144
  296. A context manager that will prevent any logging messages triggered during the body from being processed.
  297. :param highest_level: the maximum logging level in use.
  298. This would only need to be changed if a custom level greater than CRITICAL is defined.
  299. """
  300. previous_level = logging.root.manager.disable
  301. logging.disable(highest_level)
  302. try:
  303. yield
  304. finally:
  305. logging.disable(previous_level)