You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

371 lines
18KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Validate a trained YOLOv5 model accuracy on a custom dataset
  4. Usage:
  5. $ python path/to/val.py --data coco128.yaml --weights yolov5s.pt --img 640
  6. """
  7. import argparse
  8. import json
  9. import os
  10. import sys
  11. from pathlib import Path
  12. from threading import Thread
  13. import numpy as np
  14. import torch
  15. from tqdm import tqdm
  16. FILE = Path(__file__).resolve()
  17. ROOT = FILE.parents[0] # YOLOv5 root directory
  18. if str(ROOT) not in sys.path:
  19. sys.path.append(str(ROOT)) # add ROOT to PATH
  20. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
  21. from models.common import DetectMultiBackend
  22. from utils.callbacks import Callbacks
  23. from utils.datasets import create_dataloader
  24. from utils.general import (LOGGER, box_iou, check_dataset, check_img_size, check_requirements, check_yaml,
  25. coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
  26. scale_coords, xywh2xyxy, xyxy2xywh)
  27. from utils.metrics import ConfusionMatrix, ap_per_class
  28. from utils.plots import output_to_target, plot_images, plot_val_study
  29. from utils.torch_utils import select_device, time_sync
  30. def save_one_txt(predn, save_conf, shape, file):
  31. # Save one txt result
  32. gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
  33. for *xyxy, conf, cls in predn.tolist():
  34. xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
  35. line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
  36. with open(file, 'a') as f:
  37. f.write(('%g ' * len(line)).rstrip() % line + '\n')
  38. def save_one_json(predn, jdict, path, class_map):
  39. # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
  40. image_id = int(path.stem) if path.stem.isnumeric() else path.stem
  41. box = xyxy2xywh(predn[:, :4]) # xywh
  42. box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
  43. for p, b in zip(predn.tolist(), box.tolist()):
  44. jdict.append({'image_id': image_id,
  45. 'category_id': class_map[int(p[5])],
  46. 'bbox': [round(x, 3) for x in b],
  47. 'score': round(p[4], 5)})
  48. def process_batch(detections, labels, iouv):
  49. """
  50. Return correct predictions matrix. Both sets of boxes are in (x1, y1, x2, y2) format.
  51. Arguments:
  52. detections (Array[N, 6]), x1, y1, x2, y2, conf, class
  53. labels (Array[M, 5]), class, x1, y1, x2, y2
  54. Returns:
  55. correct (Array[N, 10]), for 10 IoU levels
  56. """
  57. correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
  58. iou = box_iou(labels[:, 1:], detections[:, :4])
  59. x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5])) # IoU above threshold and classes match
  60. if x[0].shape[0]:
  61. matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detection, iou]
  62. if x[0].shape[0] > 1:
  63. matches = matches[matches[:, 2].argsort()[::-1]]
  64. matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
  65. # matches = matches[matches[:, 2].argsort()[::-1]]
  66. matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
  67. matches = torch.Tensor(matches).to(iouv.device)
  68. correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
  69. return correct
  70. @torch.no_grad()
  71. def run(data,
  72. weights=None, # model.pt path(s)
  73. batch_size=32, # batch size
  74. imgsz=640, # inference size (pixels)
  75. conf_thres=0.001, # confidence threshold
  76. iou_thres=0.6, # NMS IoU threshold
  77. task='val', # train, val, test, speed or study
  78. device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
  79. workers=8, # max dataloader workers (per RANK in DDP mode)
  80. single_cls=False, # treat as single-class dataset
  81. augment=False, # augmented inference
  82. verbose=False, # verbose output
  83. save_txt=False, # save results to *.txt
  84. save_hybrid=False, # save label+prediction hybrid results to *.txt
  85. save_conf=False, # save confidences in --save-txt labels
  86. save_json=False, # save a COCO-JSON results file
  87. project=ROOT / 'runs/val', # save to project/name
  88. name='exp', # save to project/name
  89. exist_ok=False, # existing project/name ok, do not increment
  90. half=True, # use FP16 half-precision inference
  91. dnn=False, # use OpenCV DNN for ONNX inference
  92. model=None,
  93. dataloader=None,
  94. save_dir=Path(''),
  95. plots=True,
  96. callbacks=Callbacks(),
  97. compute_loss=None,
  98. ):
  99. # Initialize/load model and set device
  100. training = model is not None
  101. if training: # called by train.py
  102. device, pt, jit, engine = next(model.parameters()).device, True, False, False # get model device, PyTorch model
  103. half &= device.type != 'cpu' # half precision only supported on CUDA
  104. model.half() if half else model.float()
  105. else: # called directly
  106. device = select_device(device, batch_size=batch_size)
  107. # Directories
  108. save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
  109. (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
  110. # Load model
  111. model = DetectMultiBackend(weights, device=device, dnn=dnn)
  112. stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
  113. imgsz = check_img_size(imgsz, s=stride) # check image size
  114. half &= (pt or jit or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
  115. if pt or jit:
  116. model.model.half() if half else model.model.float()
  117. elif engine:
  118. batch_size = model.batch_size
  119. else:
  120. half = False
  121. batch_size = 1 # export.py models default to batch-size 1
  122. device = torch.device('cpu')
  123. LOGGER.info(f'Forcing --batch-size 1 square inference shape(1,3,{imgsz},{imgsz}) for non-PyTorch backends')
  124. # Data
  125. data = check_dataset(data) # check
  126. # Configure
  127. model.eval()
  128. is_coco = isinstance(data.get('val'), str) and data['val'].endswith('coco/val2017.txt') # COCO dataset
  129. nc = 1 if single_cls else int(data['nc']) # number of classes
  130. iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
  131. niou = iouv.numel()
  132. # Dataloader
  133. if not training:
  134. model.warmup(imgsz=(1, 3, imgsz, imgsz), half=half) # warmup
  135. pad = 0.0 if task == 'speed' else 0.5
  136. task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
  137. dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,
  138. workers=workers, prefix=colorstr(f'{task}: '))[0]
  139. seen = 0
  140. confusion_matrix = ConfusionMatrix(nc=nc)
  141. names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
  142. class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
  143. s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
  144. dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
  145. loss = torch.zeros(3, device=device)
  146. jdict, stats, ap, ap_class = [], [], [], []
  147. pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
  148. for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
  149. t1 = time_sync()
  150. if pt or jit or engine:
  151. im = im.to(device, non_blocking=True)
  152. targets = targets.to(device)
  153. im = im.half() if half else im.float() # uint8 to fp16/32
  154. im /= 255 # 0 - 255 to 0.0 - 1.0
  155. nb, _, height, width = im.shape # batch size, channels, height, width
  156. t2 = time_sync()
  157. dt[0] += t2 - t1
  158. # Inference
  159. out, train_out = model(im) if training else model(im, augment=augment, val=True) # inference, loss outputs
  160. dt[1] += time_sync() - t2
  161. # Loss
  162. if compute_loss:
  163. loss += compute_loss([x.float() for x in train_out], targets)[1] # box, obj, cls
  164. # NMS
  165. targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
  166. lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
  167. t3 = time_sync()
  168. out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
  169. dt[2] += time_sync() - t3
  170. # Metrics
  171. for si, pred in enumerate(out):
  172. labels = targets[targets[:, 0] == si, 1:]
  173. nl = len(labels)
  174. tcls = labels[:, 0].tolist() if nl else [] # target class
  175. path, shape = Path(paths[si]), shapes[si][0]
  176. seen += 1
  177. if len(pred) == 0:
  178. if nl:
  179. stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
  180. continue
  181. # Predictions
  182. if single_cls:
  183. pred[:, 5] = 0
  184. predn = pred.clone()
  185. scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1]) # native-space pred
  186. # Evaluate
  187. if nl:
  188. tbox = xywh2xyxy(labels[:, 1:5]) # target boxes
  189. scale_coords(im[si].shape[1:], tbox, shape, shapes[si][1]) # native-space labels
  190. labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
  191. correct = process_batch(predn, labelsn, iouv)
  192. if plots:
  193. confusion_matrix.process_batch(predn, labelsn)
  194. else:
  195. correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
  196. stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) # (correct, conf, pcls, tcls)
  197. # Save/log
  198. if save_txt:
  199. save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
  200. if save_json:
  201. save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
  202. callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
  203. # Plot images
  204. if plots and batch_i < 3:
  205. f = save_dir / f'val_batch{batch_i}_labels.jpg' # labels
  206. Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start()
  207. f = save_dir / f'val_batch{batch_i}_pred.jpg' # predictions
  208. Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()
  209. # Compute metrics
  210. stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
  211. if len(stats) and stats[0].any():
  212. tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
  213. ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
  214. mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
  215. nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
  216. else:
  217. nt = torch.zeros(1)
  218. # Print results
  219. pf = '%20s' + '%11i' * 2 + '%11.3g' * 4 # print format
  220. LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
  221. # Print results per class
  222. if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
  223. for i, c in enumerate(ap_class):
  224. LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
  225. # Print speeds
  226. t = tuple(x / seen * 1E3 for x in dt) # speeds per image
  227. if not training:
  228. shape = (batch_size, 3, imgsz, imgsz)
  229. LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
  230. # Plots
  231. if plots:
  232. confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
  233. callbacks.run('on_val_end')
  234. # Save JSON
  235. if save_json and len(jdict):
  236. w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
  237. anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json') # annotations json
  238. pred_json = str(save_dir / f"{w}_predictions.json") # predictions json
  239. LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
  240. with open(pred_json, 'w') as f:
  241. json.dump(jdict, f)
  242. try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
  243. check_requirements(['pycocotools'])
  244. from pycocotools.coco import COCO
  245. from pycocotools.cocoeval import COCOeval
  246. anno = COCO(anno_json) # init annotations api
  247. pred = anno.loadRes(pred_json) # init predictions api
  248. eval = COCOeval(anno, pred, 'bbox')
  249. if is_coco:
  250. eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files] # image IDs to evaluate
  251. eval.evaluate()
  252. eval.accumulate()
  253. eval.summarize()
  254. map, map50 = eval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5)
  255. except Exception as e:
  256. LOGGER.info(f'pycocotools unable to run: {e}')
  257. # Return results
  258. model.float() # for training
  259. if not training:
  260. s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
  261. LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
  262. maps = np.zeros(nc) + map
  263. for i, c in enumerate(ap_class):
  264. maps[c] = ap[i]
  265. return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t
  266. def parse_opt():
  267. parser = argparse.ArgumentParser()
  268. parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
  269. parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)')
  270. parser.add_argument('--batch-size', type=int, default=32, help='batch size')
  271. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
  272. parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold')
  273. parser.add_argument('--iou-thres', type=float, default=0.6, help='NMS IoU threshold')
  274. parser.add_argument('--task', default='val', help='train, val, test, speed or study')
  275. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  276. parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
  277. parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
  278. parser.add_argument('--augment', action='store_true', help='augmented inference')
  279. parser.add_argument('--verbose', action='store_true', help='report mAP by class')
  280. parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
  281. parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt')
  282. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
  283. parser.add_argument('--save-json', action='store_true', help='save a COCO-JSON results file')
  284. parser.add_argument('--project', default=ROOT / 'runs/val', help='save to project/name')
  285. parser.add_argument('--name', default='exp', help='save to project/name')
  286. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  287. parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
  288. parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
  289. opt = parser.parse_args()
  290. opt.data = check_yaml(opt.data) # check YAML
  291. opt.save_json |= opt.data.endswith('coco.yaml')
  292. opt.save_txt |= opt.save_hybrid
  293. print_args(FILE.stem, opt)
  294. return opt
  295. def main(opt):
  296. check_requirements(requirements=ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))
  297. if opt.task in ('train', 'val', 'test'): # run normally
  298. if opt.conf_thres > 0.001: # https://github.com/ultralytics/yolov5/issues/1466
  299. LOGGER.info(f'WARNING: confidence threshold {opt.conf_thres} >> 0.001 will produce invalid mAP values.')
  300. run(**vars(opt))
  301. else:
  302. weights = opt.weights if isinstance(opt.weights, list) else [opt.weights]
  303. opt.half = True # FP16 for fastest results
  304. if opt.task == 'speed': # speed benchmarks
  305. # python val.py --task speed --data coco.yaml --batch 1 --weights yolov5n.pt yolov5s.pt...
  306. opt.conf_thres, opt.iou_thres, opt.save_json = 0.25, 0.45, False
  307. for opt.weights in weights:
  308. run(**vars(opt), plots=False)
  309. elif opt.task == 'study': # speed vs mAP benchmarks
  310. # python val.py --task study --data coco.yaml --iou 0.7 --weights yolov5n.pt yolov5s.pt...
  311. for opt.weights in weights:
  312. f = f'study_{Path(opt.data).stem}_{Path(opt.weights).stem}.txt' # filename to save to
  313. x, y = list(range(256, 1536 + 128, 128)), [] # x axis (image sizes), y axis
  314. for opt.imgsz in x: # img-size
  315. LOGGER.info(f'\nRunning {f} --imgsz {opt.imgsz}...')
  316. r, _, t = run(**vars(opt), plots=False)
  317. y.append(r + t) # results and times
  318. np.savetxt(f, y, fmt='%10.4g') # save
  319. os.system('zip -r study.zip study_*.txt')
  320. plot_val_study(x=x) # plot
  321. if __name__ == "__main__":
  322. opt = parse_opt()
  323. main(opt)