无人机视角的行人小目标检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

527 lines
25KB

  1. """Utilities and tools for tracking runs with Weights & Biases."""
  2. import logging
  3. import os
  4. import sys
  5. from contextlib import contextmanager
  6. from pathlib import Path
  7. from typing import Dict
  8. import pkg_resources as pkg
  9. import yaml
  10. from tqdm import tqdm
  11. FILE = Path(__file__).resolve()
  12. ROOT = FILE.parents[3] # YOLOv5 root directory
  13. if str(ROOT) not in sys.path:
  14. sys.path.append(str(ROOT)) # add ROOT to PATH
  15. from utils.datasets import LoadImagesAndLabels, img2label_paths
  16. from utils.general import check_dataset, check_file
  17. try:
  18. import wandb
  19. assert hasattr(wandb, '__version__') # verify package import not local dir
  20. except (ImportError, AssertionError):
  21. wandb = None
  22. RANK = int(os.getenv('RANK', -1))
  23. WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
  24. def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):
  25. return from_string[len(prefix):]
  26. def check_wandb_config_file(data_config_file):
  27. wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path
  28. if Path(wandb_config).is_file():
  29. return wandb_config
  30. return data_config_file
  31. def check_wandb_dataset(data_file):
  32. is_trainset_wandb_artifact = False
  33. is_valset_wandb_artifact = False
  34. if check_file(data_file) and data_file.endswith('.yaml'):
  35. with open(data_file, errors='ignore') as f:
  36. data_dict = yaml.safe_load(f)
  37. is_trainset_wandb_artifact = (isinstance(data_dict['train'], str) and
  38. data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX))
  39. is_valset_wandb_artifact = (isinstance(data_dict['val'], str) and
  40. data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX))
  41. if is_trainset_wandb_artifact or is_valset_wandb_artifact:
  42. return data_dict
  43. else:
  44. return check_dataset(data_file)
  45. def get_run_info(run_path):
  46. run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
  47. run_id = run_path.stem
  48. project = run_path.parent.stem
  49. entity = run_path.parent.parent.stem
  50. model_artifact_name = 'run_' + run_id + '_model'
  51. return entity, project, run_id, model_artifact_name
  52. def check_wandb_resume(opt):
  53. process_wandb_config_ddp_mode(opt) if RANK not in [-1, 0] else None
  54. if isinstance(opt.resume, str):
  55. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  56. if RANK not in [-1, 0]: # For resuming DDP runs
  57. entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
  58. api = wandb.Api()
  59. artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
  60. modeldir = artifact.download()
  61. opt.weights = str(Path(modeldir) / "last.pt")
  62. return True
  63. return None
  64. def process_wandb_config_ddp_mode(opt):
  65. with open(check_file(opt.data), errors='ignore') as f:
  66. data_dict = yaml.safe_load(f) # data dict
  67. train_dir, val_dir = None, None
  68. if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):
  69. api = wandb.Api()
  70. train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias)
  71. train_dir = train_artifact.download()
  72. train_path = Path(train_dir) / 'data/images/'
  73. data_dict['train'] = str(train_path)
  74. if isinstance(data_dict['val'], str) and data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX):
  75. api = wandb.Api()
  76. val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias)
  77. val_dir = val_artifact.download()
  78. val_path = Path(val_dir) / 'data/images/'
  79. data_dict['val'] = str(val_path)
  80. if train_dir or val_dir:
  81. ddp_data_path = str(Path(val_dir) / 'wandb_local_data.yaml')
  82. with open(ddp_data_path, 'w') as f:
  83. yaml.safe_dump(data_dict, f)
  84. opt.data = ddp_data_path
  85. class WandbLogger():
  86. """Log training runs, datasets, models, and predictions to Weights & Biases.
  87. This logger sends information to W&B at wandb.ai. By default, this information
  88. includes hyperparameters, system configuration and metrics, model metrics,
  89. and basic data metrics and analyses.
  90. By providing additional command line arguments to train.py, datasets,
  91. models and predictions can also be logged.
  92. For more on how this logger is used, see the Weights & Biases documentation:
  93. https://docs.wandb.com/guides/integrations/yolov5
  94. """
  95. def __init__(self, opt, run_id=None, job_type='Training'):
  96. """
  97. - Initialize WandbLogger instance
  98. - Upload dataset if opt.upload_dataset is True
  99. - Setup trainig processes if job_type is 'Training'
  100. arguments:
  101. opt (namespace) -- Commandline arguments for this run
  102. run_id (str) -- Run ID of W&B run to be resumed
  103. job_type (str) -- To set the job_type for this run
  104. """
  105. # Pre-training routine --
  106. self.job_type = job_type
  107. self.wandb, self.wandb_run = wandb, None if not wandb else wandb.run
  108. self.val_artifact, self.train_artifact = None, None
  109. self.train_artifact_path, self.val_artifact_path = None, None
  110. self.result_artifact = None
  111. self.val_table, self.result_table = None, None
  112. self.bbox_media_panel_images = []
  113. self.val_table_path_map = None
  114. self.max_imgs_to_log = 16
  115. self.wandb_artifact_data_dict = None
  116. self.data_dict = None
  117. # It's more elegant to stick to 1 wandb.init call,
  118. # but useful config data is overwritten in the WandbLogger's wandb.init call
  119. if isinstance(opt.resume, str): # checks resume from artifact
  120. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  121. entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
  122. model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
  123. assert wandb, 'install wandb to resume wandb runs'
  124. # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
  125. self.wandb_run = wandb.init(id=run_id,
  126. project=project,
  127. entity=entity,
  128. resume='allow',
  129. allow_val_change=True)
  130. opt.resume = model_artifact_name
  131. elif self.wandb:
  132. self.wandb_run = wandb.init(config=opt,
  133. resume="allow",
  134. project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
  135. entity=opt.entity,
  136. name=opt.name if opt.name != 'exp' else None,
  137. job_type=job_type,
  138. id=run_id,
  139. allow_val_change=True) if not wandb.run else wandb.run
  140. if self.wandb_run:
  141. if self.job_type == 'Training':
  142. if opt.upload_dataset:
  143. if not opt.resume:
  144. self.wandb_artifact_data_dict = self.check_and_upload_dataset(opt)
  145. if opt.resume:
  146. # resume from artifact
  147. if isinstance(opt.resume, str) and opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  148. self.data_dict = dict(self.wandb_run.config.data_dict)
  149. else: # local resume
  150. self.data_dict = check_wandb_dataset(opt.data)
  151. else:
  152. self.data_dict = check_wandb_dataset(opt.data)
  153. self.wandb_artifact_data_dict = self.wandb_artifact_data_dict or self.data_dict
  154. # write data_dict to config. useful for resuming from artifacts. Do this only when not resuming.
  155. self.wandb_run.config.update({'data_dict': self.wandb_artifact_data_dict},
  156. allow_val_change=True)
  157. self.setup_training(opt)
  158. if self.job_type == 'Dataset Creation':
  159. self.data_dict = self.check_and_upload_dataset(opt)
  160. def check_and_upload_dataset(self, opt):
  161. """
  162. Check if the dataset format is compatible and upload it as W&B artifact
  163. arguments:
  164. opt (namespace)-- Commandline arguments for current run
  165. returns:
  166. Updated dataset info dictionary where local dataset paths are replaced by WAND_ARFACT_PREFIX links.
  167. """
  168. assert wandb, 'Install wandb to upload dataset'
  169. config_path = self.log_dataset_artifact(opt.data,
  170. opt.single_cls,
  171. 'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
  172. print("Created dataset config file ", config_path)
  173. with open(config_path, errors='ignore') as f:
  174. wandb_data_dict = yaml.safe_load(f)
  175. return wandb_data_dict
  176. def setup_training(self, opt):
  177. """
  178. Setup the necessary processes for training YOLO models:
  179. - Attempt to download model checkpoint and dataset artifacts if opt.resume stats with WANDB_ARTIFACT_PREFIX
  180. - Update data_dict, to contain info of previous run if resumed and the paths of dataset artifact if downloaded
  181. - Setup log_dict, initialize bbox_interval
  182. arguments:
  183. opt (namespace) -- commandline arguments for this run
  184. """
  185. self.log_dict, self.current_epoch = {}, 0
  186. self.bbox_interval = opt.bbox_interval
  187. if isinstance(opt.resume, str):
  188. modeldir, _ = self.download_model_artifact(opt)
  189. if modeldir:
  190. self.weights = Path(modeldir) / "last.pt"
  191. config = self.wandb_run.config
  192. opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
  193. self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
  194. config.hyp
  195. data_dict = self.data_dict
  196. if self.val_artifact is None: # If --upload_dataset is set, use the existing artifact, don't download
  197. self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
  198. opt.artifact_alias)
  199. self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
  200. opt.artifact_alias)
  201. if self.train_artifact_path is not None:
  202. train_path = Path(self.train_artifact_path) / 'data/images/'
  203. data_dict['train'] = str(train_path)
  204. if self.val_artifact_path is not None:
  205. val_path = Path(self.val_artifact_path) / 'data/images/'
  206. data_dict['val'] = str(val_path)
  207. if self.val_artifact is not None:
  208. self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
  209. self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
  210. self.val_table = self.val_artifact.get("val")
  211. if self.val_table_path_map is None:
  212. self.map_val_table_path()
  213. if opt.bbox_interval == -1:
  214. self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
  215. train_from_artifact = self.train_artifact_path is not None and self.val_artifact_path is not None
  216. # Update the the data_dict to point to local artifacts dir
  217. if train_from_artifact:
  218. self.data_dict = data_dict
  219. def download_dataset_artifact(self, path, alias):
  220. """
  221. download the model checkpoint artifact if the path starts with WANDB_ARTIFACT_PREFIX
  222. arguments:
  223. path -- path of the dataset to be used for training
  224. alias (str)-- alias of the artifact to be download/used for training
  225. returns:
  226. (str, wandb.Artifact) -- path of the downladed dataset and it's corresponding artifact object if dataset
  227. is found otherwise returns (None, None)
  228. """
  229. if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
  230. artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
  231. dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\", "/"))
  232. assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
  233. datadir = dataset_artifact.download()
  234. return datadir, dataset_artifact
  235. return None, None
  236. def download_model_artifact(self, opt):
  237. """
  238. download the model checkpoint artifact if the resume path starts with WANDB_ARTIFACT_PREFIX
  239. arguments:
  240. opt (namespace) -- Commandline arguments for this run
  241. """
  242. if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
  243. model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
  244. assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
  245. modeldir = model_artifact.download()
  246. epochs_trained = model_artifact.metadata.get('epochs_trained')
  247. total_epochs = model_artifact.metadata.get('total_epochs')
  248. is_finished = total_epochs is None
  249. assert not is_finished, 'training is finished, can only resume incomplete runs.'
  250. return modeldir, model_artifact
  251. return None, None
  252. def log_model(self, path, opt, epoch, fitness_score, best_model=False):
  253. """
  254. Log the model checkpoint as W&B artifact
  255. arguments:
  256. path (Path) -- Path of directory containing the checkpoints
  257. opt (namespace) -- Command line arguments for this run
  258. epoch (int) -- Current epoch number
  259. fitness_score (float) -- fitness score for current epoch
  260. best_model (boolean) -- Boolean representing if the current checkpoint is the best yet.
  261. """
  262. model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
  263. 'original_url': str(path),
  264. 'epochs_trained': epoch + 1,
  265. 'save period': opt.save_period,
  266. 'project': opt.project,
  267. 'total_epochs': opt.epochs,
  268. 'fitness_score': fitness_score
  269. })
  270. model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
  271. wandb.log_artifact(model_artifact,
  272. aliases=['latest', 'last', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
  273. print("Saving model artifact on epoch ", epoch + 1)
  274. def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
  275. """
  276. Log the dataset as W&B artifact and return the new data file with W&B links
  277. arguments:
  278. data_file (str) -- the .yaml file with information about the dataset like - path, classes etc.
  279. single_class (boolean) -- train multi-class data as single-class
  280. project (str) -- project name. Used to construct the artifact path
  281. overwrite_config (boolean) -- overwrites the data.yaml file if set to true otherwise creates a new
  282. file with _wandb postfix. Eg -> data_wandb.yaml
  283. returns:
  284. the new .yaml file with artifact links. it can be used to start training directly from artifacts
  285. """
  286. self.data_dict = check_dataset(data_file) # parse and check
  287. data = dict(self.data_dict)
  288. nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
  289. names = {k: v for k, v in enumerate(names)} # to index dictionary
  290. self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
  291. data['train'], rect=True, batch_size=1), names, name='train') if data.get('train') else None
  292. self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
  293. data['val'], rect=True, batch_size=1), names, name='val') if data.get('val') else None
  294. if data.get('train'):
  295. data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
  296. if data.get('val'):
  297. data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
  298. path = Path(data_file).stem
  299. path = (path if overwrite_config else path + '_wandb') + '.yaml' # updated data.yaml path
  300. data.pop('download', None)
  301. data.pop('path', None)
  302. with open(path, 'w') as f:
  303. yaml.safe_dump(data, f)
  304. if self.job_type == 'Training': # builds correct artifact pipeline graph
  305. self.wandb_run.use_artifact(self.val_artifact)
  306. self.wandb_run.use_artifact(self.train_artifact)
  307. self.val_artifact.wait()
  308. self.val_table = self.val_artifact.get('val')
  309. self.map_val_table_path()
  310. else:
  311. self.wandb_run.log_artifact(self.train_artifact)
  312. self.wandb_run.log_artifact(self.val_artifact)
  313. return path
  314. def map_val_table_path(self):
  315. """
  316. Map the validation dataset Table like name of file -> it's id in the W&B Table.
  317. Useful for - referencing artifacts for evaluation.
  318. """
  319. self.val_table_path_map = {}
  320. print("Mapping dataset")
  321. for i, data in enumerate(tqdm(self.val_table.data)):
  322. self.val_table_path_map[data[3]] = data[0]
  323. def create_dataset_table(self, dataset: LoadImagesAndLabels, class_to_id: Dict[int,str], name: str = 'dataset'):
  324. """
  325. Create and return W&B artifact containing W&B Table of the dataset.
  326. arguments:
  327. dataset -- instance of LoadImagesAndLabels class used to iterate over the data to build Table
  328. class_to_id -- hash map that maps class ids to labels
  329. name -- name of the artifact
  330. returns:
  331. dataset artifact to be logged or used
  332. """
  333. # TODO: Explore multiprocessing to slpit this loop parallely| This is essential for speeding up the the logging
  334. artifact = wandb.Artifact(name=name, type="dataset")
  335. img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None
  336. img_files = tqdm(dataset.img_files) if not img_files else img_files
  337. for img_file in img_files:
  338. if Path(img_file).is_dir():
  339. artifact.add_dir(img_file, name='data/images')
  340. labels_path = 'labels'.join(dataset.path.rsplit('images', 1))
  341. artifact.add_dir(labels_path, name='data/labels')
  342. else:
  343. artifact.add_file(img_file, name='data/images/' + Path(img_file).name)
  344. label_file = Path(img2label_paths([img_file])[0])
  345. artifact.add_file(str(label_file),
  346. name='data/labels/' + label_file.name) if label_file.exists() else None
  347. table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
  348. class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
  349. for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
  350. box_data, img_classes = [], {}
  351. for cls, *xywh in labels[:, 1:].tolist():
  352. cls = int(cls)
  353. box_data.append({"position": {"middle": [xywh[0], xywh[1]], "width": xywh[2], "height": xywh[3]},
  354. "class_id": cls,
  355. "box_caption": "%s" % (class_to_id[cls])})
  356. img_classes[cls] = class_to_id[cls]
  357. boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space
  358. table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), list(img_classes.values()),
  359. Path(paths).name)
  360. artifact.add(table, name)
  361. return artifact
  362. def log_training_progress(self, predn, path, names):
  363. """
  364. Build evaluation Table. Uses reference from validation dataset table.
  365. arguments:
  366. predn (list): list of predictions in the native space in the format - [xmin, ymin, xmax, ymax, confidence, class]
  367. path (str): local path of the current evaluation image
  368. names (dict(int, str)): hash map that maps class ids to labels
  369. """
  370. class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
  371. box_data = []
  372. total_conf = 0
  373. for *xyxy, conf, cls in predn.tolist():
  374. if conf >= 0.25:
  375. box_data.append(
  376. {"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
  377. "class_id": int(cls),
  378. "box_caption": f"{names[cls]} {conf:.3f}",
  379. "scores": {"class_score": conf},
  380. "domain": "pixel"})
  381. total_conf += conf
  382. boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
  383. id = self.val_table_path_map[Path(path).name]
  384. self.result_table.add_data(self.current_epoch,
  385. id,
  386. self.val_table.data[id][1],
  387. wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
  388. total_conf / max(1, len(box_data))
  389. )
  390. def val_one_image(self, pred, predn, path, names, im):
  391. """
  392. Log validation data for one image. updates the result Table if validation dataset is uploaded and log bbox media panel
  393. arguments:
  394. pred (list): list of scaled predictions in the format - [xmin, ymin, xmax, ymax, confidence, class]
  395. predn (list): list of predictions in the native space - [xmin, ymin, xmax, ymax, confidence, class]
  396. path (str): local path of the current evaluation image
  397. """
  398. if self.val_table and self.result_table: # Log Table if Val dataset is uploaded as artifact
  399. self.log_training_progress(predn, path, names)
  400. if len(self.bbox_media_panel_images) < self.max_imgs_to_log and self.current_epoch > 0:
  401. if self.current_epoch % self.bbox_interval == 0:
  402. box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
  403. "class_id": int(cls),
  404. "box_caption": f"{names[cls]} {conf:.3f}",
  405. "scores": {"class_score": conf},
  406. "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
  407. boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
  408. self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))
  409. def log(self, log_dict):
  410. """
  411. save the metrics to the logging dictionary
  412. arguments:
  413. log_dict (Dict) -- metrics/media to be logged in current step
  414. """
  415. if self.wandb_run:
  416. for key, value in log_dict.items():
  417. self.log_dict[key] = value
  418. def end_epoch(self, best_result=False):
  419. """
  420. commit the log_dict, model artifacts and Tables to W&B and flush the log_dict.
  421. arguments:
  422. best_result (boolean): Boolean representing if the result of this evaluation is best or not
  423. """
  424. if self.wandb_run:
  425. with all_logging_disabled():
  426. if self.bbox_media_panel_images:
  427. self.log_dict["BoundingBoxDebugger"] = self.bbox_media_panel_images
  428. wandb.log(self.log_dict)
  429. self.log_dict = {}
  430. self.bbox_media_panel_images = []
  431. if self.result_artifact:
  432. self.result_artifact.add(self.result_table, 'result')
  433. wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
  434. ('best' if best_result else '')])
  435. wandb.log({"evaluation": self.result_table})
  436. self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
  437. self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
  438. def finish_run(self):
  439. """
  440. Log metrics if any and finish the current W&B run
  441. """
  442. if self.wandb_run:
  443. if self.log_dict:
  444. with all_logging_disabled():
  445. wandb.log(self.log_dict)
  446. wandb.run.finish()
  447. @contextmanager
  448. def all_logging_disabled(highest_level=logging.CRITICAL):
  449. """ source - https://gist.github.com/simon-weber/7853144
  450. A context manager that will prevent any logging messages triggered during the body from being processed.
  451. :param highest_level: the maximum logging level in use.
  452. This would only need to be changed if a custom level greater than CRITICAL is defined.
  453. """
  454. previous_level = logging.root.manager.disable
  455. logging.disable(highest_level)
  456. try:
  457. yield
  458. finally:
  459. logging.disable(previous_level)