import argparse
import glob
import json
import os
import shutil
from pathlib import Path

import numpy as np
import torch
import yaml
from tqdm import tqdm

from models.experimental import attempt_load
from utils.datasets import create_dataloader
from utils.general import (
    coco80_to_coco91_class, check_dataset, check_file, check_img_size, compute_loss, non_max_suppression,
    scale_coords, xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, ap_per_class)
from utils.torch_utils import select_device, time_synchronized


def test(data,
         weights=None,
         batch_size=16,
         imgsz=640,
         conf_thres=0.001,
         iou_thres=0.6,  # for NMS
         save_json=False,
         single_cls=False,
         augment=False,
         verbose=False,
         model=None,
         dataloader=None,
         save_dir='',
         merge=False,
         save_txt=False):
    # Initialize/load model and set device
    training = model is not None
    if training:  # called by train.py
        device = next(model.parameters()).device  # get model device

    else:  # called directly
        device = select_device(opt.device, batch_size=batch_size)
        merge, save_txt = opt.merge, opt.save_txt  # use Merge NMS, save *.txt labels
        if save_txt:
            out = Path('inference/output')
            if os.path.exists(out):
                shutil.rmtree(out)  # delete output folder
            os.makedirs(out)  # make new output folder

        # Remove previous
        for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')):
            os.remove(f)

        # Load model
        model = attempt_load(weights, map_location=device)  # load FP32 model
        imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size

        # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
        # if device.type != 'cpu' and torch.cuda.device_count() > 1:
        #     model = nn.DataParallel(model)

    # Half
    half = device.type != 'cpu'  # half precision only supported on CUDA
    if half:
        model.half()

    # Configure
    model.eval()
    with open(data) as f:
        data = yaml.load(f, Loader=yaml.FullLoader)  # data dict
    check_dataset(data)  # check
    nc = 1 if single_cls else int(data['nc'])  # number of classes
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()
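    # Note (added for clarity): iouv holds the 10 IoU thresholds 0.50, 0.55, ..., 0.95
    # used by the COCO-style mAP@0.5:0.95 metric; each prediction is later scored
    # against all 10 thresholds at once, yielding one boolean per threshold.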

    # Dataloader
    if not training:
        img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
        _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
        path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
        dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt,
                                       hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]
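    # Note (added for clarity): the zero-image forward pass above warms the model up so
    # the first timed batch is not skewed by one-time CUDA initialization. rect=True with
    # pad=0.5 requests letterboxed rectangular batches, which, as I understand it, reduces
    # padding pixels compared to forcing every image into a full imgsz x imgsz square.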

    seen = 0
    names = model.names if hasattr(model, 'names') else model.module.names
    coco91class = coco80_to_coco91_class()
    s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
    p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
    loss = torch.zeros(3, device=device)
    jdict, stats, ap, ap_class = [], [], [], []
    for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
        img = img.to(device, non_blocking=True)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        targets = targets.to(device)
        nb, _, height, width = img.shape  # batch size, channels, height, width
        whwh = torch.Tensor([width, height, width, height]).to(device)

        # Disable gradients
        with torch.no_grad():
            # Run model
            t = time_synchronized()
            inf_out, train_out = model(img, augment=augment)  # inference and training outputs
            t0 += time_synchronized() - t

            # Compute loss
            if training:  # if model has loss hyperparameters
                loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]  # GIoU, obj, cls

            # Run NMS
            t = time_synchronized()
            output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
            t1 += time_synchronized() - t
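        # Note (added for clarity): non_max_suppression returns one entry per image in the
        # batch, either an (n, 6) tensor of [x1, y1, x2, y2, conf, cls] rows or None when
        # no detections survive the thresholds, e.g. (illustrative values, not executed):
        #   output = [tensor([[52.0, 38.5, 310.2, 290.1, 0.91, 16.0], ...]), None, ...]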

        # Statistics per image
        for si, pred in enumerate(output):
            labels = targets[targets[:, 0] == si, 1:]
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            seen += 1

            if pred is None:
                if nl:
                    stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
                continue

            # Append to text file
            if save_txt:
                gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]]  # normalization gain whwh
                txt_path = str(out / Path(paths[si]).stem)
                pred[:, :4] = scale_coords(img[si].shape[1:], pred[:, :4], shapes[si][0], shapes[si][1])  # to original
                for *xyxy, conf, cls in pred:
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    with open(txt_path + '.txt', 'a') as f:
                        f.write(('%g ' * 5 + '\n') % (cls, *xywh))  # label format

            # Clip boxes to image bounds
            clip_coords(pred, (height, width))
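            # Note (added for clarity): each saved *.txt line follows the YOLO label format
            # 'class x_center y_center width height', with coordinates normalized to 0-1 by
            # the original image size, e.g. (illustrative): '16 0.512 0.403 0.218 0.197'.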

            # Append to pycocotools JSON dictionary
            if save_json:
                # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
                image_id = Path(paths[si]).stem
                box = pred[:, :4].clone()  # xyxy
                scale_coords(img[si].shape[1:], box, shapes[si][0], shapes[si][1])  # to original shape
                box = xyxy2xywh(box)  # xywh
                box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
                for p, b in zip(pred.tolist(), box.tolist()):
                    jdict.append({'image_id': int(image_id) if image_id.isnumeric() else image_id,
                                  'category_id': coco91class[int(p[5])],
                                  'bbox': [round(x, 3) for x in b],
                                  'score': round(p[4], 5)})
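            # Note (added for clarity): COCO expects bbox as [x_min, y_min, width, height]
            # in original-image pixels, hence the center-to-corner shift above, and
            # coco91class maps the 80 contiguous training class indices back to the
            # original 91 COCO category IDs.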

            # Assign all predictions as incorrect
            correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
            if nl:
                detected = []  # target indices
                tcls_tensor = labels[:, 0]

                # target boxes
                tbox = xywh2xyxy(labels[:, 1:5]) * whwh

                # Per target class
                for cls in torch.unique(tcls_tensor):
                    ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1)  # target indices
                    pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1)  # prediction indices

                    # Search for detections
                    if pi.shape[0]:
                        # Prediction to target ious
                        ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1)  # best ious, indices

                        # Append detections
                        for j in (ious > iouv[0]).nonzero(as_tuple=False):
                            d = ti[i[j]]  # detected target
                            if d not in detected:
                                detected.append(d)
                                correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
                                if len(detected) == nl:  # all targets already located in image
                                    break

            # Append statistics (correct, conf, pcls, tcls)
            stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
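            # Worked example for the matching above (added for clarity): if a prediction's
            # best same-class IoU with an unmatched target is 0.72, its row of `correct`
            # against iouv = 0.50 ... 0.95 becomes
            #   [True, True, True, True, True, False, False, False, False, False]
            # i.e. a true positive at mAP@0.5 but not at the stricter thresholds.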

        # Plot images
        if batch_i < 1:
            f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i)  # filename
            plot_images(img, targets, paths, str(f), names)  # ground truth
            f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i)
            plot_images(img, output_to_target(output, width, height), paths, str(f), names)  # predictions

    # Compute statistics
    stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
    if len(stats) and stats[0].any():
        p, r, ap, f1, ap_class = ap_per_class(*stats)
        p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1)  # [P, R, AP@0.5, AP@0.5:0.95]
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
    else:
        nt = torch.zeros(1)
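    # Note (added for clarity): ap_per_class returns per-class arrays with one column per
    # IoU threshold in iouv, so indexing [:, 0] selects the 0.5-threshold values and
    # ap.mean(1) averages the 10 thresholds into AP@0.5:0.95; mp, mr, map50 and map are
    # then the means over classes reported on the 'all' row below.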

    # Print results
    pf = '%20s' + '%12.3g' * 6  # print format
    print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))

    # Print results per class
    if verbose and nc > 1 and len(stats):
        for i, c in enumerate(ap_class):
            print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

    # Print speeds
    t = tuple(x / seen * 1E3 for x in (t0, t1, t0 + t1)) + (imgsz, imgsz, batch_size)  # tuple
    if not training:
        print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)

    # Save JSON
    if save_json and len(jdict):
        f = 'detections_val2017_%s_results.json' % \
            (weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '')  # filename
        print('\nCOCO mAP with pycocotools... saving %s...' % f)
        with open(f, 'w') as file:
            json.dump(jdict, file)

        try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
            from pycocotools.coco import COCO
            from pycocotools.cocoeval import COCOeval

            imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files]
            cocoGt = COCO(glob.glob('../coco/annotations/instances_val*.json')[0])  # initialize COCO ground truth api
            cocoDt = cocoGt.loadRes(f)  # initialize COCO pred api
            cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
            cocoEval.params.imgIds = imgIds  # image IDs to evaluate
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()
            map, map50 = cocoEval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
        except Exception as e:
            print('ERROR: pycocotools unable to run: %s' % e)
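    # Note (added for clarity): the pycocotools step assumes numeric COCO image filenames
    # (int(Path(x).stem)) and ground-truth annotations under ../coco/annotations/, so it
    # only works out of the box on the official val2017 layout; anything else raises
    # inside the try block and is reported by the except branch above.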

    # Return results
    model.float()  # for training
    maps = np.zeros(nc) + map
    for i, c in enumerate(ap_class):
        maps[c] = ap[i]
    return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t
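
# Note (added for clarity): test() returns ((mP, mR, mAP@0.5, mAP@0.5:0.95, *val losses),
# per-class mAP array, timing tuple); as the '# called by train.py' comment above
# indicates, train.py passes model= and dataloader= to reuse this function mid-training.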


if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog='test.py')
    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
    parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.65, help='IOU threshold for NMS')
    parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file')
    parser.add_argument('--task', default='val', help="'val', 'test', 'study'")
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--merge', action='store_true', help='use Merge NMS')
    parser.add_argument('--verbose', action='store_true', help='report mAP by class')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    opt = parser.parse_args()
    opt.save_json |= opt.data.endswith('coco.yaml')
    opt.data = check_file(opt.data)  # check file
    print(opt)
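    # Example invocations (illustrative, assuming the weights and data files are present):
    #   python test.py --data data/coco128.yaml --weights yolov5s.pt --img-size 640
    #   python test.py --data coco.yaml --weights yolov5x.pt --task test --save-json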

    if opt.task in ['val', 'test']:  # run normally
        test(opt.data,
             opt.weights,
             opt.batch_size,
             opt.img_size,
             opt.conf_thres,
             opt.iou_thres,
             opt.save_json,
             opt.single_cls,
             opt.augment,
             opt.verbose)

    elif opt.task == 'study':  # run over a range of settings and save/plot
        for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
            f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem)  # filename to save to
            x = list(range(352, 832, 64))  # x axis
            y = []  # y axis
            for i in x:  # img-size
                print('\nRunning %s point %s...' % (f, i))
                r, _, t = test(opt.data, weights, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json)
                y.append(r + t)  # results and times
            np.savetxt(f, y, fmt='%10.4g')  # save
        os.system('zip -r study.zip study_*.txt')
        # plot_study_txt(f, x)  # plot
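    # Note (added for clarity): the study sweep evaluates image sizes 352, 416, ..., 768
    # (range(352, 832, 64)) for each weights file, writing one row of metrics plus timings
    # per size, which the commented-out plot_study_txt call can later turn into a
    # speed/accuracy curve.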