Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domuzīmes ('-') un var būt līdz 35 simboliem gara.

268 rindas
14KB

  1. import argparse
  2. import json
  3. import os
  4. import shutil
  5. import sys
  6. import torch
  7. import yaml
  8. from datetime import datetime
  9. from pathlib import Path
  10. from tqdm import tqdm
  11. sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
  12. from utils.datasets import LoadImagesAndLabels
  13. from utils.datasets import img2label_paths
  14. from utils.general import colorstr, xywh2xyxy, check_dataset
  15. try:
  16. import wandb
  17. from wandb import init, finish
  18. except ImportError:
  19. wandb = None
  20. WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'  # prefix that marks a resume/dataset string as a W&B artifact reference
  21. def remove_prefix(from_string, prefix):
  22. return from_string[len(prefix):]
  23. def check_wandb_config_file(data_config_file):
  24. wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path
  25. if Path(wandb_config).is_file():
  26. return wandb_config
  27. return data_config_file
  28. def resume_and_get_id(opt):
  29. # It's more elegant to stick to 1 wandb.init call, but as useful config data is overwritten in the WandbLogger's wandb.init call
  30. if isinstance(opt.resume, str):
  31. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  32. run_path = Path(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX))
  33. run_id = run_path.stem
  34. project = run_path.parent.stem
  35. model_artifact_name = WANDB_ARTIFACT_PREFIX + 'run_' + run_id + '_model'
  36. assert wandb, 'install wandb to resume wandb runs'
  37. # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
  38. run = wandb.init(id=run_id, project=project, resume='allow')
  39. opt.resume = model_artifact_name
  40. return run
  41. return None
  42. class WandbLogger():
  43. def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
  44. # Pre-training routine --
  45. self.job_type = job_type
  46. self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
  47. if self.wandb:
  48. self.wandb_run = wandb.init(config=opt,
  49. resume="allow",
  50. project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
  51. name=name,
  52. job_type=job_type,
  53. id=run_id) if not wandb.run else wandb.run
  54. if self.job_type == 'Training':
  55. if not opt.resume:
  56. wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
  57. # Info useful for resuming from artifacts
  58. self.wandb_run.config.opt = vars(opt)
  59. self.wandb_run.config.data_dict = wandb_data_dict
  60. self.data_dict = self.setup_training(opt, data_dict)
  61. if self.job_type == 'Dataset Creation':
  62. self.data_dict = self.check_and_upload_dataset(opt)
  63. else:
  64. print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
  65. def check_and_upload_dataset(self, opt):
  66. assert wandb, 'Install wandb to upload dataset'
  67. check_dataset(self.data_dict)
  68. config_path = self.log_dataset_artifact(opt.data,
  69. opt.single_cls,
  70. 'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
  71. print("Created dataset config file ", config_path)
  72. with open(config_path) as f:
  73. wandb_data_dict = yaml.load(f, Loader=yaml.SafeLoader)
  74. return wandb_data_dict
  75. def setup_training(self, opt, data_dict):
  76. self.log_dict, self.current_epoch, self.log_imgs = {}, 0, 16 # Logging Constants
  77. self.bbox_interval = opt.bbox_interval
  78. if isinstance(opt.resume, str):
  79. modeldir, _ = self.download_model_artifact(opt)
  80. if modeldir:
  81. self.weights = Path(modeldir) / "last.pt"
  82. config = self.wandb_run.config
  83. opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
  84. self.weights), config.save_period, config.total_batch_size, config.bbox_interval, config.epochs, \
  85. config.opt['hyp']
  86. data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume
  87. if 'val_artifact' not in self.__dict__: # If --upload_dataset is set, use the existing artifact, don't download
  88. self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
  89. opt.artifact_alias)
  90. self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
  91. opt.artifact_alias)
  92. self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
  93. if self.train_artifact_path is not None:
  94. train_path = Path(self.train_artifact_path) / 'data/images/'
  95. data_dict['train'] = str(train_path)
  96. if self.val_artifact_path is not None:
  97. val_path = Path(self.val_artifact_path) / 'data/images/'
  98. data_dict['val'] = str(val_path)
  99. self.val_table = self.val_artifact.get("val")
  100. self.map_val_table_path()
  101. if self.val_artifact is not None:
  102. self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
  103. self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
  104. if opt.bbox_interval == -1:
  105. self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
  106. return data_dict
  107. def download_dataset_artifact(self, path, alias):
  108. if path.startswith(WANDB_ARTIFACT_PREFIX):
  109. dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
  110. assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
  111. datadir = dataset_artifact.download()
  112. return datadir, dataset_artifact
  113. return None, None
  114. def download_model_artifact(self, opt):
  115. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  116. model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
  117. assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
  118. modeldir = model_artifact.download()
  119. epochs_trained = model_artifact.metadata.get('epochs_trained')
  120. total_epochs = model_artifact.metadata.get('total_epochs')
  121. assert epochs_trained < total_epochs, 'training to %g epochs is finished, nothing to resume.' % (
  122. total_epochs)
  123. return modeldir, model_artifact
  124. return None, None
  125. def log_model(self, path, opt, epoch, fitness_score, best_model=False):
  126. model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
  127. 'original_url': str(path),
  128. 'epochs_trained': epoch + 1,
  129. 'save period': opt.save_period,
  130. 'project': opt.project,
  131. 'total_epochs': opt.epochs,
  132. 'fitness_score': fitness_score
  133. })
  134. model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
  135. wandb.log_artifact(model_artifact,
  136. aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
  137. print("Saving model artifact on epoch ", epoch + 1)
  138. def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
  139. with open(data_file) as f:
  140. data = yaml.load(f, Loader=yaml.SafeLoader) # data dict
  141. nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
  142. names = {k: v for k, v in enumerate(names)} # to index dictionary
  143. self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
  144. data['train']), names, name='train') if data.get('train') else None
  145. self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
  146. data['val']), names, name='val') if data.get('val') else None
  147. if data.get('train'):
  148. data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
  149. if data.get('val'):
  150. data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
  151. path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path
  152. data.pop('download', None)
  153. with open(path, 'w') as f:
  154. yaml.dump(data, f)
  155. if self.job_type == 'Training': # builds correct artifact pipeline graph
  156. self.wandb_run.use_artifact(self.val_artifact)
  157. self.wandb_run.use_artifact(self.train_artifact)
  158. self.val_artifact.wait()
  159. self.val_table = self.val_artifact.get('val')
  160. self.map_val_table_path()
  161. else:
  162. self.wandb_run.log_artifact(self.train_artifact)
  163. self.wandb_run.log_artifact(self.val_artifact)
  164. return path
  165. def map_val_table_path(self):
  166. self.val_table_map = {}
  167. print("Mapping dataset")
  168. for i, data in enumerate(tqdm(self.val_table.data)):
  169. self.val_table_map[data[3]] = data[0]
  170. def create_dataset_table(self, dataset, class_to_id, name='dataset'):
  171. # TODO: Explore multiprocessing to slpit this loop parallely| This is essential for speeding up the the logging
  172. artifact = wandb.Artifact(name=name, type="dataset")
  173. for img_file in tqdm([dataset.path]) if Path(dataset.path).is_dir() else tqdm(dataset.img_files):
  174. if Path(img_file).is_dir():
  175. artifact.add_dir(img_file, name='data/images')
  176. labels_path = 'labels'.join(dataset.path.rsplit('images', 1))
  177. artifact.add_dir(labels_path, name='data/labels')
  178. else:
  179. artifact.add_file(img_file, name='data/images/' + Path(img_file).name)
  180. label_file = Path(img2label_paths([img_file])[0])
  181. artifact.add_file(str(label_file),
  182. name='data/labels/' + label_file.name) if label_file.exists() else None
  183. table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
  184. class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
  185. for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
  186. height, width = shapes[0]
  187. labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4))) * torch.Tensor([width, height, width, height])
  188. box_data, img_classes = [], {}
  189. for cls, *xyxy in labels[:, 1:].tolist():
  190. cls = int(cls)
  191. box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
  192. "class_id": cls,
  193. "box_caption": "%s" % (class_to_id[cls]),
  194. "scores": {"acc": 1},
  195. "domain": "pixel"})
  196. img_classes[cls] = class_to_id[cls]
  197. boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space
  198. table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes),
  199. Path(paths).name)
  200. artifact.add(table, name)
  201. return artifact
  202. def log_training_progress(self, predn, path, names):
  203. if self.val_table and self.result_table:
  204. class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
  205. box_data = []
  206. total_conf = 0
  207. for *xyxy, conf, cls in predn.tolist():
  208. if conf >= 0.25:
  209. box_data.append(
  210. {"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
  211. "class_id": int(cls),
  212. "box_caption": "%s %.3f" % (names[cls], conf),
  213. "scores": {"class_score": conf},
  214. "domain": "pixel"})
  215. total_conf = total_conf + conf
  216. boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
  217. id = self.val_table_map[Path(path).name]
  218. self.result_table.add_data(self.current_epoch,
  219. id,
  220. wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
  221. total_conf / max(1, len(box_data))
  222. )
  223. def log(self, log_dict):
  224. if self.wandb_run:
  225. for key, value in log_dict.items():
  226. self.log_dict[key] = value
  227. def end_epoch(self, best_result=False):
  228. if self.wandb_run:
  229. wandb.log(self.log_dict)
  230. self.log_dict = {}
  231. if self.result_artifact:
  232. train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
  233. self.result_artifact.add(train_results, 'result')
  234. wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch),
  235. ('best' if best_result else '')])
  236. self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
  237. self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
  238. def finish_run(self):
  239. if self.wandb_run:
  240. if self.log_dict:
  241. wandb.log(self.log_dict)
  242. wandb.run.finish()