Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

304 lines
16KB

  1. import json
  2. import sys
  3. from pathlib import Path
  4. import torch
  5. import yaml
  6. from tqdm import tqdm
  7. sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
  8. from utils.datasets import LoadImagesAndLabels
  9. from utils.datasets import img2label_paths
  10. from utils.general import colorstr, xywh2xyxy, check_dataset, check_file
  11. try:
  12. import wandb
  13. from wandb import init, finish
  14. except ImportError:
  15. wandb = None
  16. WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
  17. def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):
  18. return from_string[len(prefix):]
  19. def check_wandb_config_file(data_config_file):
  20. wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path
  21. if Path(wandb_config).is_file():
  22. return wandb_config
  23. return data_config_file
  24. def get_run_info(run_path):
  25. run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
  26. run_id = run_path.stem
  27. project = run_path.parent.stem
  28. model_artifact_name = 'run_' + run_id + '_model'
  29. return run_id, project, model_artifact_name
  30. def check_wandb_resume(opt):
  31. process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None
  32. if isinstance(opt.resume, str):
  33. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  34. if opt.global_rank not in [-1, 0]: # For resuming DDP runs
  35. run_id, project, model_artifact_name = get_run_info(opt.resume)
  36. api = wandb.Api()
  37. artifact = api.artifact(project + '/' + model_artifact_name + ':latest')
  38. modeldir = artifact.download()
  39. opt.weights = str(Path(modeldir) / "last.pt")
  40. return True
  41. return None
  42. def process_wandb_config_ddp_mode(opt):
  43. with open(check_file(opt.data)) as f:
  44. data_dict = yaml.safe_load(f) # data dict
  45. train_dir, val_dir = None, None
  46. if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):
  47. api = wandb.Api()
  48. train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias)
  49. train_dir = train_artifact.download()
  50. train_path = Path(train_dir) / 'data/images/'
  51. data_dict['train'] = str(train_path)
  52. if isinstance(data_dict['val'], str) and data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX):
  53. api = wandb.Api()
  54. val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias)
  55. val_dir = val_artifact.download()
  56. val_path = Path(val_dir) / 'data/images/'
  57. data_dict['val'] = str(val_path)
  58. if train_dir or val_dir:
  59. ddp_data_path = str(Path(val_dir) / 'wandb_local_data.yaml')
  60. with open(ddp_data_path, 'w') as f:
  61. yaml.safe_dump(data_dict, f)
  62. opt.data = ddp_data_path
  63. class WandbLogger():
  64. def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
  65. # Pre-training routine --
  66. self.job_type = job_type
  67. self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
  68. # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
  69. if isinstance(opt.resume, str): # checks resume from artifact
  70. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  71. run_id, project, model_artifact_name = get_run_info(opt.resume)
  72. model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
  73. assert wandb, 'install wandb to resume wandb runs'
  74. # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
  75. self.wandb_run = wandb.init(id=run_id, project=project, resume='allow')
  76. opt.resume = model_artifact_name
  77. elif self.wandb:
  78. self.wandb_run = wandb.init(config=opt,
  79. resume="allow",
  80. project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
  81. name=name,
  82. job_type=job_type,
  83. id=run_id) if not wandb.run else wandb.run
  84. if self.wandb_run:
  85. if self.job_type == 'Training':
  86. if not opt.resume:
  87. wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
  88. # Info useful for resuming from artifacts
  89. self.wandb_run.config.opt = vars(opt)
  90. self.wandb_run.config.data_dict = wandb_data_dict
  91. self.data_dict = self.setup_training(opt, data_dict)
  92. if self.job_type == 'Dataset Creation':
  93. self.data_dict = self.check_and_upload_dataset(opt)
  94. else:
  95. prefix = colorstr('wandb: ')
  96. print(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
  97. def check_and_upload_dataset(self, opt):
  98. assert wandb, 'Install wandb to upload dataset'
  99. check_dataset(self.data_dict)
  100. config_path = self.log_dataset_artifact(check_file(opt.data),
  101. opt.single_cls,
  102. 'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
  103. print("Created dataset config file ", config_path)
  104. with open(config_path) as f:
  105. wandb_data_dict = yaml.safe_load(f)
  106. return wandb_data_dict
  107. def setup_training(self, opt, data_dict):
  108. self.log_dict, self.current_epoch, self.log_imgs = {}, 0, 16 # Logging Constants
  109. self.bbox_interval = opt.bbox_interval
  110. if isinstance(opt.resume, str):
  111. modeldir, _ = self.download_model_artifact(opt)
  112. if modeldir:
  113. self.weights = Path(modeldir) / "last.pt"
  114. config = self.wandb_run.config
  115. opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
  116. self.weights), config.save_period, config.total_batch_size, config.bbox_interval, config.epochs, \
  117. config.opt['hyp']
  118. data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume
  119. if 'val_artifact' not in self.__dict__: # If --upload_dataset is set, use the existing artifact, don't download
  120. self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
  121. opt.artifact_alias)
  122. self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
  123. opt.artifact_alias)
  124. self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
  125. if self.train_artifact_path is not None:
  126. train_path = Path(self.train_artifact_path) / 'data/images/'
  127. data_dict['train'] = str(train_path)
  128. if self.val_artifact_path is not None:
  129. val_path = Path(self.val_artifact_path) / 'data/images/'
  130. data_dict['val'] = str(val_path)
  131. self.val_table = self.val_artifact.get("val")
  132. self.map_val_table_path()
  133. if self.val_artifact is not None:
  134. self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
  135. self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
  136. if opt.bbox_interval == -1:
  137. self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
  138. return data_dict
  139. def download_dataset_artifact(self, path, alias):
  140. if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
  141. artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
  142. dataset_artifact = wandb.use_artifact(artifact_path.as_posix())
  143. assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
  144. datadir = dataset_artifact.download()
  145. return datadir, dataset_artifact
  146. return None, None
  147. def download_model_artifact(self, opt):
  148. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  149. model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
  150. assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
  151. modeldir = model_artifact.download()
  152. epochs_trained = model_artifact.metadata.get('epochs_trained')
  153. total_epochs = model_artifact.metadata.get('total_epochs')
  154. assert epochs_trained < total_epochs, 'training to %g epochs is finished, nothing to resume.' % (
  155. total_epochs)
  156. return modeldir, model_artifact
  157. return None, None
  158. def log_model(self, path, opt, epoch, fitness_score, best_model=False):
  159. model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
  160. 'original_url': str(path),
  161. 'epochs_trained': epoch + 1,
  162. 'save period': opt.save_period,
  163. 'project': opt.project,
  164. 'total_epochs': opt.epochs,
  165. 'fitness_score': fitness_score
  166. })
  167. model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
  168. wandb.log_artifact(model_artifact,
  169. aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
  170. print("Saving model artifact on epoch ", epoch + 1)
  171. def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
  172. with open(data_file) as f:
  173. data = yaml.safe_load(f) # data dict
  174. nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
  175. names = {k: v for k, v in enumerate(names)} # to index dictionary
  176. self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
  177. data['train'], rect=True, batch_size=1), names, name='train') if data.get('train') else None
  178. self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
  179. data['val'], rect=True, batch_size=1), names, name='val') if data.get('val') else None
  180. if data.get('train'):
  181. data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
  182. if data.get('val'):
  183. data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
  184. path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path
  185. data.pop('download', None)
  186. with open(path, 'w') as f:
  187. yaml.safe_dump(data, f)
  188. if self.job_type == 'Training': # builds correct artifact pipeline graph
  189. self.wandb_run.use_artifact(self.val_artifact)
  190. self.wandb_run.use_artifact(self.train_artifact)
  191. self.val_artifact.wait()
  192. self.val_table = self.val_artifact.get('val')
  193. self.map_val_table_path()
  194. else:
  195. self.wandb_run.log_artifact(self.train_artifact)
  196. self.wandb_run.log_artifact(self.val_artifact)
  197. return path
  198. def map_val_table_path(self):
  199. self.val_table_map = {}
  200. print("Mapping dataset")
  201. for i, data in enumerate(tqdm(self.val_table.data)):
  202. self.val_table_map[data[3]] = data[0]
  203. def create_dataset_table(self, dataset, class_to_id, name='dataset'):
  204. # TODO: Explore multiprocessing to slpit this loop parallely| This is essential for speeding up the the logging
  205. artifact = wandb.Artifact(name=name, type="dataset")
  206. img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None
  207. img_files = tqdm(dataset.img_files) if not img_files else img_files
  208. for img_file in img_files:
  209. if Path(img_file).is_dir():
  210. artifact.add_dir(img_file, name='data/images')
  211. labels_path = 'labels'.join(dataset.path.rsplit('images', 1))
  212. artifact.add_dir(labels_path, name='data/labels')
  213. else:
  214. artifact.add_file(img_file, name='data/images/' + Path(img_file).name)
  215. label_file = Path(img2label_paths([img_file])[0])
  216. artifact.add_file(str(label_file),
  217. name='data/labels/' + label_file.name) if label_file.exists() else None
  218. table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
  219. class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
  220. for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
  221. box_data, img_classes = [], {}
  222. for cls, *xywh in labels[:, 1:].tolist():
  223. cls = int(cls)
  224. box_data.append({"position": {"middle": [xywh[0], xywh[1]], "width": xywh[2], "height": xywh[3]},
  225. "class_id": cls,
  226. "box_caption": "%s" % (class_to_id[cls])})
  227. img_classes[cls] = class_to_id[cls]
  228. boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space
  229. table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes),
  230. Path(paths).name)
  231. artifact.add(table, name)
  232. return artifact
  233. def log_training_progress(self, predn, path, names):
  234. if self.val_table and self.result_table:
  235. class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
  236. box_data = []
  237. total_conf = 0
  238. for *xyxy, conf, cls in predn.tolist():
  239. if conf >= 0.25:
  240. box_data.append(
  241. {"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
  242. "class_id": int(cls),
  243. "box_caption": "%s %.3f" % (names[cls], conf),
  244. "scores": {"class_score": conf},
  245. "domain": "pixel"})
  246. total_conf = total_conf + conf
  247. boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
  248. id = self.val_table_map[Path(path).name]
  249. self.result_table.add_data(self.current_epoch,
  250. id,
  251. wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
  252. total_conf / max(1, len(box_data))
  253. )
  254. def log(self, log_dict):
  255. if self.wandb_run:
  256. for key, value in log_dict.items():
  257. self.log_dict[key] = value
  258. def end_epoch(self, best_result=False):
  259. if self.wandb_run:
  260. wandb.log(self.log_dict)
  261. self.log_dict = {}
  262. if self.result_artifact:
  263. train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
  264. self.result_artifact.add(train_results, 'result')
  265. wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch),
  266. ('best' if best_result else '')])
  267. self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
  268. self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
  269. def finish_run(self):
  270. if self.wandb_run:
  271. if self.log_dict:
  272. wandb.log(self.log_dict)
  273. wandb.run.finish()