選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

307 行
14KB

  1. import argparse
  2. import glob
  3. import json
  4. import os
  5. import shutil
  6. from pathlib import Path
  7. import numpy as np
  8. import torch
  9. import yaml
  10. from tqdm import tqdm
  11. from models.experimental import attempt_load
  12. from utils.datasets import create_dataloader
  13. from utils.general import (
  14. coco80_to_coco91_class, check_dataset, check_file, check_img_size, compute_loss, non_max_suppression, scale_coords,
  15. xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, ap_per_class, set_logging)
  16. from utils.torch_utils import select_device, time_synchronized
  17. def test(data,
  18. weights=None,
  19. batch_size=16,
  20. imgsz=640,
  21. conf_thres=0.001,
  22. iou_thres=0.6, # for NMS
  23. save_json=False,
  24. single_cls=False,
  25. augment=False,
  26. verbose=False,
  27. model=None,
  28. dataloader=None,
  29. save_dir=Path(''), # for saving images
  30. save_txt=False, # for auto-labelling
  31. save_conf=False,
  32. plots=True):
  33. # Initialize/load model and set device
  34. training = model is not None
  35. if training: # called by train.py
  36. device = next(model.parameters()).device # get model device
  37. else: # called directly
  38. set_logging()
  39. device = select_device(opt.device, batch_size=batch_size)
  40. save_txt = opt.save_txt # save *.txt labels
  41. # Remove previous
  42. if os.path.exists(save_dir):
  43. shutil.rmtree(save_dir) # delete dir
  44. os.makedirs(save_dir) # make new dir
  45. if save_txt:
  46. out = save_dir / 'autolabels'
  47. if os.path.exists(out):
  48. shutil.rmtree(out) # delete dir
  49. os.makedirs(out) # make new dir
  50. # Load model
  51. model = attempt_load(weights, map_location=device) # load FP32 model
  52. imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
  53. # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
  54. # if device.type != 'cpu' and torch.cuda.device_count() > 1:
  55. # model = nn.DataParallel(model)
  56. # Half
  57. half = device.type != 'cpu' # half precision only supported on CUDA
  58. if half:
  59. model.half()
  60. # Configure
  61. model.eval()
  62. with open(data) as f:
  63. data = yaml.load(f, Loader=yaml.FullLoader) # model dict
  64. check_dataset(data) # check
  65. nc = 1 if single_cls else int(data['nc']) # number of classes
  66. iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
  67. niou = iouv.numel()
  68. # Dataloader
  69. if not training:
  70. img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
  71. _ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
  72. path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images
  73. dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt,
  74. hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]
  75. seen = 0
  76. names = model.names if hasattr(model, 'names') else model.module.names
  77. coco91class = coco80_to_coco91_class()
  78. s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
  79. p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
  80. loss = torch.zeros(3, device=device)
  81. jdict, stats, ap, ap_class = [], [], [], []
  82. for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
  83. img = img.to(device, non_blocking=True)
  84. img = img.half() if half else img.float() # uint8 to fp16/32
  85. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  86. targets = targets.to(device)
  87. nb, _, height, width = img.shape # batch size, channels, height, width
  88. whwh = torch.Tensor([width, height, width, height]).to(device)
  89. # Disable gradients
  90. with torch.no_grad():
  91. # Run model
  92. t = time_synchronized()
  93. inf_out, train_out = model(img, augment=augment) # inference and training outputs
  94. t0 += time_synchronized() - t
  95. # Compute loss
  96. if training: # if model has loss hyperparameters
  97. loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls
  98. # Run NMS
  99. t = time_synchronized()
  100. output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
  101. t1 += time_synchronized() - t
  102. # Statistics per image
  103. for si, pred in enumerate(output):
  104. labels = targets[targets[:, 0] == si, 1:]
  105. nl = len(labels)
  106. tcls = labels[:, 0].tolist() if nl else [] # target class
  107. seen += 1
  108. if pred is None:
  109. if nl:
  110. stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
  111. continue
  112. # Append to text file
  113. if save_txt:
  114. gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]] # normalization gain whwh
  115. x = pred.clone()
  116. x[:, :4] = scale_coords(img[si].shape[1:], x[:, :4], shapes[si][0], shapes[si][1]) # to original
  117. for *xyxy, conf, cls in x:
  118. xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
  119. line = (cls, conf, *xywh) if save_conf else (cls, *xywh) # label format
  120. with open(str(out / Path(paths[si]).stem) + '.txt', 'a') as f:
  121. f.write(('%g ' * len(line) + '\n') % line)
  122. # Clip boxes to image bounds
  123. clip_coords(pred, (height, width))
  124. # Append to pycocotools JSON dictionary
  125. if save_json:
  126. # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
  127. image_id = Path(paths[si]).stem
  128. box = pred[:, :4].clone() # xyxy
  129. scale_coords(img[si].shape[1:], box, shapes[si][0], shapes[si][1]) # to original shape
  130. box = xyxy2xywh(box) # xywh
  131. box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
  132. for p, b in zip(pred.tolist(), box.tolist()):
  133. jdict.append({'image_id': int(image_id) if image_id.isnumeric() else image_id,
  134. 'category_id': coco91class[int(p[5])],
  135. 'bbox': [round(x, 3) for x in b],
  136. 'score': round(p[4], 5)})
  137. # Assign all predictions as incorrect
  138. correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
  139. if nl:
  140. detected = [] # target indices
  141. tcls_tensor = labels[:, 0]
  142. # target boxes
  143. tbox = xywh2xyxy(labels[:, 1:5]) * whwh
  144. # Per target class
  145. for cls in torch.unique(tcls_tensor):
  146. ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1) # prediction indices
  147. pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1) # target indices
  148. # Search for detections
  149. if pi.shape[0]:
  150. # Prediction to target ious
  151. ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices
  152. # Append detections
  153. detected_set = set()
  154. for j in (ious > iouv[0]).nonzero(as_tuple=False):
  155. d = ti[i[j]] # detected target
  156. if d.item() not in detected_set:
  157. detected_set.add(d.item())
  158. detected.append(d)
  159. correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn
  160. if len(detected) == nl: # all targets already located in image
  161. break
  162. # Append statistics (correct, conf, pcls, tcls)
  163. stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
  164. # Plot images
  165. if plots and batch_i < 1:
  166. f = save_dir / ('test_batch%g_gt.jpg' % batch_i) # filename
  167. plot_images(img, targets, paths, str(f), names) # ground truth
  168. f = save_dir / ('test_batch%g_pred.jpg' % batch_i)
  169. plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions
  170. # Compute statistics
  171. stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
  172. if len(stats) and stats[0].any():
  173. p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, fname=save_dir / 'precision-recall_curve.png')
  174. p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1) # [P, R, AP@0.5, AP@0.5:0.95]
  175. mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
  176. nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
  177. else:
  178. nt = torch.zeros(1)
  179. # Print results
  180. pf = '%20s' + '%12.3g' * 6 # print format
  181. print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
  182. # Print results per class
  183. if verbose and nc > 1 and len(stats):
  184. for i, c in enumerate(ap_class):
  185. print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
  186. # Print speeds
  187. t = tuple(x / seen * 1E3 for x in (t0, t1, t0 + t1)) + (imgsz, imgsz, batch_size) # tuple
  188. if not training:
  189. print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)
  190. # Save JSON
  191. if save_json and len(jdict):
  192. f = 'detections_val2017_%s_results.json' % \
  193. (weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '') # filename
  194. print('\nCOCO mAP with pycocotools... saving %s...' % f)
  195. with open(f, 'w') as file:
  196. json.dump(jdict, file)
  197. try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
  198. from pycocotools.coco import COCO
  199. from pycocotools.cocoeval import COCOeval
  200. imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files]
  201. cocoGt = COCO(glob.glob('../coco/annotations/instances_val*.json')[0]) # initialize COCO ground truth api
  202. cocoDt = cocoGt.loadRes(f) # initialize COCO pred api
  203. cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
  204. cocoEval.params.imgIds = imgIds # image IDs to evaluate
  205. cocoEval.evaluate()
  206. cocoEval.accumulate()
  207. cocoEval.summarize()
  208. map, map50 = cocoEval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5)
  209. except Exception as e:
  210. print('ERROR: pycocotools unable to run: %s' % e)
  211. # Return results
  212. model.float() # for training
  213. maps = np.zeros(nc) + map
  214. for i, c in enumerate(ap_class):
  215. maps[c] = ap[i]
  216. return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t
  217. if __name__ == '__main__':
  218. parser = argparse.ArgumentParser(prog='test.py')
  219. parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
  220. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
  221. parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
  222. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
  223. parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
  224. parser.add_argument('--iou-thres', type=float, default=0.65, help='IOU threshold for NMS')
  225. parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file')
  226. parser.add_argument('--task', default='val', help="'val', 'test', 'study'")
  227. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  228. parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
  229. parser.add_argument('--augment', action='store_true', help='augmented inference')
  230. parser.add_argument('--verbose', action='store_true', help='report mAP by class')
  231. parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
  232. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
  233. parser.add_argument('--save-dir', type=str, default='runs/test', help='directory to save results')
  234. opt = parser.parse_args()
  235. opt.save_json |= opt.data.endswith('coco.yaml')
  236. opt.data = check_file(opt.data) # check file
  237. print(opt)
  238. if opt.task in ['val', 'test']: # run normally
  239. test(opt.data,
  240. opt.weights,
  241. opt.batch_size,
  242. opt.img_size,
  243. opt.conf_thres,
  244. opt.iou_thres,
  245. opt.save_json,
  246. opt.single_cls,
  247. opt.augment,
  248. opt.verbose,
  249. save_dir=Path(opt.save_dir),
  250. save_txt=opt.save_txt,
  251. save_conf=opt.save_conf,
  252. )
  253. print('Results saved to %s' % opt.save_dir)
  254. elif opt.task == 'study': # run over a range of settings and save/plot
  255. for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
  256. f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem) # filename to save to
  257. x = list(range(320, 800, 64)) # x axis
  258. y = [] # y axis
  259. for i in x: # img-size
  260. print('\nRunning %s point %s...' % (f, i))
  261. r, _, t = test(opt.data, weights, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json)
  262. y.append(r + t) # results and times
  263. np.savetxt(f, y, fmt='%10.4g') # save
  264. os.system('zip -r study.zip study_*.txt')
  265. # utils.general.plot_study_txt(f, x) # plot