Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

306 lines
16KB

  1. import argparse
  2. import json
  3. import os
  4. import shutil
  5. import sys
  6. import torch
  7. import yaml
  8. from datetime import datetime
  9. from pathlib import Path
  10. from tqdm import tqdm
  11. sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
  12. from utils.datasets import LoadImagesAndLabels
  13. from utils.datasets import img2label_paths
  14. from utils.general import colorstr, xywh2xyxy, check_dataset
  15. try:
  16. import wandb
  17. from wandb import init, finish
  18. except ImportError:
  19. wandb = None
  20. WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
  21. def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):
  22. return from_string[len(prefix):]
  23. def check_wandb_config_file(data_config_file):
  24. wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path
  25. if Path(wandb_config).is_file():
  26. return wandb_config
  27. return data_config_file
  28. def get_run_info(run_path):
  29. run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
  30. run_id = run_path.stem
  31. project = run_path.parent.stem
  32. model_artifact_name = 'run_' + run_id + '_model'
  33. return run_id, project, model_artifact_name
  34. def check_wandb_resume(opt):
  35. process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None
  36. if isinstance(opt.resume, str):
  37. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  38. if opt.global_rank not in [-1, 0]: # For resuming DDP runs
  39. run_id, project, model_artifact_name = get_run_info(opt.resume)
  40. api = wandb.Api()
  41. artifact = api.artifact(project + '/' + model_artifact_name + ':latest')
  42. modeldir = artifact.download()
  43. opt.weights = str(Path(modeldir) / "last.pt")
  44. return True
  45. return None
  46. def process_wandb_config_ddp_mode(opt):
  47. with open(opt.data) as f:
  48. data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
  49. train_dir, val_dir = None, None
  50. if data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):
  51. api = wandb.Api()
  52. train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias)
  53. train_dir = train_artifact.download()
  54. train_path = Path(train_dir) / 'data/images/'
  55. data_dict['train'] = str(train_path)
  56. if data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX):
  57. api = wandb.Api()
  58. val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias)
  59. val_dir = val_artifact.download()
  60. val_path = Path(val_dir) / 'data/images/'
  61. data_dict['val'] = str(val_path)
  62. if train_dir or val_dir:
  63. ddp_data_path = str(Path(val_dir) / 'wandb_local_data.yaml')
  64. with open(ddp_data_path, 'w') as f:
  65. yaml.dump(data_dict, f)
  66. opt.data = ddp_data_path
  67. class WandbLogger():
  68. def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
  69. # Pre-training routine --
  70. self.job_type = job_type
  71. self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
  72. # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
  73. if isinstance(opt.resume, str): # checks resume from artifact
  74. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  75. run_id, project, model_artifact_name = get_run_info(opt.resume)
  76. model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
  77. assert wandb, 'install wandb to resume wandb runs'
  78. # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
  79. self.wandb_run = wandb.init(id=run_id, project=project, resume='allow')
  80. opt.resume = model_artifact_name
  81. elif self.wandb:
  82. self.wandb_run = wandb.init(config=opt,
  83. resume="allow",
  84. project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
  85. name=name,
  86. job_type=job_type,
  87. id=run_id) if not wandb.run else wandb.run
  88. if self.wandb_run:
  89. if self.job_type == 'Training':
  90. if not opt.resume:
  91. wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
  92. # Info useful for resuming from artifacts
  93. self.wandb_run.config.opt = vars(opt)
  94. self.wandb_run.config.data_dict = wandb_data_dict
  95. self.data_dict = self.setup_training(opt, data_dict)
  96. if self.job_type == 'Dataset Creation':
  97. self.data_dict = self.check_and_upload_dataset(opt)
  98. else:
  99. print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
  100. def check_and_upload_dataset(self, opt):
  101. assert wandb, 'Install wandb to upload dataset'
  102. check_dataset(self.data_dict)
  103. config_path = self.log_dataset_artifact(opt.data,
  104. opt.single_cls,
  105. 'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
  106. print("Created dataset config file ", config_path)
  107. with open(config_path) as f:
  108. wandb_data_dict = yaml.load(f, Loader=yaml.SafeLoader)
  109. return wandb_data_dict
  110. def setup_training(self, opt, data_dict):
  111. self.log_dict, self.current_epoch, self.log_imgs = {}, 0, 16 # Logging Constants
  112. self.bbox_interval = opt.bbox_interval
  113. if isinstance(opt.resume, str):
  114. modeldir, _ = self.download_model_artifact(opt)
  115. if modeldir:
  116. self.weights = Path(modeldir) / "last.pt"
  117. config = self.wandb_run.config
  118. opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
  119. self.weights), config.save_period, config.total_batch_size, config.bbox_interval, config.epochs, \
  120. config.opt['hyp']
  121. data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume
  122. if 'val_artifact' not in self.__dict__: # If --upload_dataset is set, use the existing artifact, don't download
  123. self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
  124. opt.artifact_alias)
  125. self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
  126. opt.artifact_alias)
  127. self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
  128. if self.train_artifact_path is not None:
  129. train_path = Path(self.train_artifact_path) / 'data/images/'
  130. data_dict['train'] = str(train_path)
  131. if self.val_artifact_path is not None:
  132. val_path = Path(self.val_artifact_path) / 'data/images/'
  133. data_dict['val'] = str(val_path)
  134. self.val_table = self.val_artifact.get("val")
  135. self.map_val_table_path()
  136. if self.val_artifact is not None:
  137. self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
  138. self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
  139. if opt.bbox_interval == -1:
  140. self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
  141. return data_dict
  142. def download_dataset_artifact(self, path, alias):
  143. if path.startswith(WANDB_ARTIFACT_PREFIX):
  144. dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
  145. assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
  146. datadir = dataset_artifact.download()
  147. return datadir, dataset_artifact
  148. return None, None
  149. def download_model_artifact(self, opt):
  150. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  151. model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
  152. assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
  153. modeldir = model_artifact.download()
  154. epochs_trained = model_artifact.metadata.get('epochs_trained')
  155. total_epochs = model_artifact.metadata.get('total_epochs')
  156. assert epochs_trained < total_epochs, 'training to %g epochs is finished, nothing to resume.' % (
  157. total_epochs)
  158. return modeldir, model_artifact
  159. return None, None
  160. def log_model(self, path, opt, epoch, fitness_score, best_model=False):
  161. model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
  162. 'original_url': str(path),
  163. 'epochs_trained': epoch + 1,
  164. 'save period': opt.save_period,
  165. 'project': opt.project,
  166. 'total_epochs': opt.epochs,
  167. 'fitness_score': fitness_score
  168. })
  169. model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
  170. wandb.log_artifact(model_artifact,
  171. aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
  172. print("Saving model artifact on epoch ", epoch + 1)
  173. def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
  174. with open(data_file) as f:
  175. data = yaml.load(f, Loader=yaml.SafeLoader) # data dict
  176. nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
  177. names = {k: v for k, v in enumerate(names)} # to index dictionary
  178. self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
  179. data['train']), names, name='train') if data.get('train') else None
  180. self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
  181. data['val']), names, name='val') if data.get('val') else None
  182. if data.get('train'):
  183. data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
  184. if data.get('val'):
  185. data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
  186. path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path
  187. data.pop('download', None)
  188. with open(path, 'w') as f:
  189. yaml.dump(data, f)
  190. if self.job_type == 'Training': # builds correct artifact pipeline graph
  191. self.wandb_run.use_artifact(self.val_artifact)
  192. self.wandb_run.use_artifact(self.train_artifact)
  193. self.val_artifact.wait()
  194. self.val_table = self.val_artifact.get('val')
  195. self.map_val_table_path()
  196. else:
  197. self.wandb_run.log_artifact(self.train_artifact)
  198. self.wandb_run.log_artifact(self.val_artifact)
  199. return path
  200. def map_val_table_path(self):
  201. self.val_table_map = {}
  202. print("Mapping dataset")
  203. for i, data in enumerate(tqdm(self.val_table.data)):
  204. self.val_table_map[data[3]] = data[0]
  205. def create_dataset_table(self, dataset, class_to_id, name='dataset'):
  206. # TODO: Explore multiprocessing to slpit this loop parallely| This is essential for speeding up the the logging
  207. artifact = wandb.Artifact(name=name, type="dataset")
  208. for img_file in tqdm([dataset.path]) if Path(dataset.path).is_dir() else tqdm(dataset.img_files):
  209. if Path(img_file).is_dir():
  210. artifact.add_dir(img_file, name='data/images')
  211. labels_path = 'labels'.join(dataset.path.rsplit('images', 1))
  212. artifact.add_dir(labels_path, name='data/labels')
  213. else:
  214. artifact.add_file(img_file, name='data/images/' + Path(img_file).name)
  215. label_file = Path(img2label_paths([img_file])[0])
  216. artifact.add_file(str(label_file),
  217. name='data/labels/' + label_file.name) if label_file.exists() else None
  218. table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
  219. class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
  220. for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
  221. height, width = shapes[0]
  222. labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4))) * torch.Tensor([width, height, width, height])
  223. box_data, img_classes = [], {}
  224. for cls, *xyxy in labels[:, 1:].tolist():
  225. cls = int(cls)
  226. box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
  227. "class_id": cls,
  228. "box_caption": "%s" % (class_to_id[cls]),
  229. "scores": {"acc": 1},
  230. "domain": "pixel"})
  231. img_classes[cls] = class_to_id[cls]
  232. boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space
  233. table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes),
  234. Path(paths).name)
  235. artifact.add(table, name)
  236. return artifact
  237. def log_training_progress(self, predn, path, names):
  238. if self.val_table and self.result_table:
  239. class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
  240. box_data = []
  241. total_conf = 0
  242. for *xyxy, conf, cls in predn.tolist():
  243. if conf >= 0.25:
  244. box_data.append(
  245. {"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
  246. "class_id": int(cls),
  247. "box_caption": "%s %.3f" % (names[cls], conf),
  248. "scores": {"class_score": conf},
  249. "domain": "pixel"})
  250. total_conf = total_conf + conf
  251. boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
  252. id = self.val_table_map[Path(path).name]
  253. self.result_table.add_data(self.current_epoch,
  254. id,
  255. wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
  256. total_conf / max(1, len(box_data))
  257. )
  258. def log(self, log_dict):
  259. if self.wandb_run:
  260. for key, value in log_dict.items():
  261. self.log_dict[key] = value
  262. def end_epoch(self, best_result=False):
  263. if self.wandb_run:
  264. wandb.log(self.log_dict)
  265. self.log_dict = {}
  266. if self.result_artifact:
  267. train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
  268. self.result_artifact.add(train_results, 'result')
  269. wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch),
  270. ('best' if best_result else '')])
  271. self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
  272. self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
  273. def finish_run(self):
  274. if self.wandb_run:
  275. if self.log_dict:
  276. wandb.log(self.log_dict)
  277. wandb.run.finish()