import argparse
import glob
import json
import os
import shutil
from pathlib import Path

import numpy as np
import torch
import yaml
from tqdm import tqdm

from models.experimental import attempt_load
from utils.datasets import create_dataloader
from utils.general import (
    coco80_to_coco91_class, check_dataset, check_file, check_img_size, compute_loss, non_max_suppression,
    scale_coords, xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, ap_per_class,
    set_logging)
from utils.torch_utils import select_device, time_synchronized

from sotabencheval.object_detection import COCOEvaluator
from sotabencheval.utils import is_server

DATA_ROOT = './.data/vision/coco' if is_server() else '../coco'  # sotabench data dir
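# On the sotabench evaluation server is_server() is True and COCO is mounted at
# ./.data/vision/coco; for local runs the script assumes the usual YOLOv5 layout
# with a coco/ directory checked out alongside the repository.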


def test(data,
         weights=None,
         batch_size=16,
         imgsz=640,
         conf_thres=0.001,
         iou_thres=0.6,  # for NMS
         save_json=False,
         single_cls=False,
         augment=False,
         verbose=False,
         model=None,
         dataloader=None,
         save_dir='',
         merge=False,
         save_txt=False):
    # Initialize/load model and set device
    training = model is not None
    if training:  # called by train.py
        device = next(model.parameters()).device  # get model device

    else:  # called directly
        set_logging()
        device = select_device(opt.device, batch_size=batch_size)
        merge, save_txt = opt.merge, opt.save_txt  # use Merge NMS, save *.txt labels
        if save_txt:
            out = Path('inference/output')
            if os.path.exists(out):
                shutil.rmtree(out)  # delete output folder
            os.makedirs(out)  # make new output folder

        # Remove previous
        for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')):
            os.remove(f)

        # Load model
        model = attempt_load(weights, map_location=device)  # load FP32 model
        imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size

        # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
        # if device.type != 'cpu' and torch.cuda.device_count() > 1:
        #     model = nn.DataParallel(model)

    # Half
    half = device.type != 'cpu'  # half precision only supported on CUDA
    if half:
        model.half()

    # Configure
    model.eval()
    with open(data) as f:
        data = yaml.load(f, Loader=yaml.FullLoader)  # data dict
    check_dataset(data)  # check
    nc = 1 if single_cls else int(data['nc'])  # number of classes
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()

    # Dataloader
    if not training:
        img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
        _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
        path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
        dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt,
                                       hyp=None, augment=False, cache=True, pad=0.5, rect=True)[0]

    seen = 0
    names = model.names if hasattr(model, 'names') else model.module.names
    coco91class = coco80_to_coco91_class()
    s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
    p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
    loss = torch.zeros(3, device=device)
    jdict, stats, ap, ap_class = [], [], [], []

    evaluator = COCOEvaluator(root=DATA_ROOT, model_name=opt.weights.replace('.pt', ''))
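    # The sotabencheval COCOEvaluator accumulates detections in COCO results
    # format via add() and submits them with save(); model_name is derived from
    # the weights filename. Note that --weights is declared with nargs='+' below,
    # so a value passed on the command line arrives as a list, which would break
    # the .replace() call; the string default is assumed here.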

    for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
        img = img.to(device, non_blocking=True)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        targets = targets.to(device)
        nb, _, height, width = img.shape  # batch size, channels, height, width
        whwh = torch.Tensor([width, height, width, height]).to(device)

        # Disable gradients
        with torch.no_grad():
            # Run model
            t = time_synchronized()
            inf_out, train_out = model(img, augment=augment)  # inference and training outputs
            t0 += time_synchronized() - t

            # Compute loss
            if training:  # if model has loss hyperparameters
                loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]  # box, obj, cls

            # Run NMS
            t = time_synchronized()
            output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
            t1 += time_synchronized() - t
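
        # Each element of output is an (n, 6) tensor of detections for one image,
        # with rows (x1, y1, x2, y2, conf, cls) in letterboxed-image pixels, or
        # None when no detections survive the thresholds.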

        # Statistics per image
        for si, pred in enumerate(output):
            labels = targets[targets[:, 0] == si, 1:]
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            seen += 1

            if pred is None:
                if nl:
                    stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
                continue

            # Append to text file
            if save_txt:
                gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]]  # normalization gain whwh
                x = pred.clone()
                x[:, :4] = scale_coords(img[si].shape[1:], x[:, :4], shapes[si][0], shapes[si][1])  # to original
                for *xyxy, conf, cls in x:
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    with open(str(out / Path(paths[si]).stem) + '.txt', 'a') as f:
                        f.write(('%g ' * 5 + '\n') % (cls, *xywh))  # label format

            # Clip boxes to image bounds
            clip_coords(pred, (height, width))

            # Append to pycocotools JSON dictionary
            if save_json:
                # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
                image_id = Path(paths[si]).stem
                box = pred[:, :4].clone()  # xyxy
                scale_coords(img[si].shape[1:], box, shapes[si][0], shapes[si][1])  # to original shape
                box = xyxy2xywh(box)  # xywh
                box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
                for p, b in zip(pred.tolist(), box.tolist()):
                    result = {'image_id': int(image_id) if image_id.isnumeric() else image_id,
                              'category_id': coco91class[int(p[5])],
                              'bbox': [round(x, 3) for x in b],
                              'score': round(p[4], 5)}
                    jdict.append(result)
                    # evaluator.add([result])
                    # if evaluator.cache_exists:
                    #     break
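                    # Streaming each result via evaluator.add([result]) and
                    # early-exiting on evaluator.cache_exists (sotabencheval's
                    # cache of previously submitted runs) is disabled here; all
                    # detections are submitted in one bulk add() after the loop.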

            # # Assign all predictions as incorrect
            # correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
            # if nl:
            #     detected = []  # target indices
            #     tcls_tensor = labels[:, 0]
            #
            #     # target boxes
            #     tbox = xywh2xyxy(labels[:, 1:5]) * whwh
            #
            #     # Per target class
            #     for cls in torch.unique(tcls_tensor):
            #         ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1)  # target indices
            #         pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1)  # prediction indices
            #
            #         # Search for detections
            #         if pi.shape[0]:
            #             # Prediction to target ious
            #             ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1)  # best ious, indices
            #
            #             # Append detections
            #             detected_set = set()
            #             for j in (ious > iouv[0]).nonzero(as_tuple=False):
            #                 d = ti[i[j]]  # detected target
            #                 if d.item() not in detected_set:
            #                     detected_set.add(d.item())
            #                     detected.append(d)
            #                     correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
            #                     if len(detected) == nl:  # all targets already located in image
            #                         break
            #
            # # Append statistics (correct, conf, pcls, tcls)
            # stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

        # # Plot images
        # if batch_i < 1:
        #     f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i)  # filename
        #     plot_images(img, targets, paths, str(f), names)  # ground truth
        #     f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i)
        #     plot_images(img, output_to_target(output, width, height), paths, str(f), names)  # predictions

    evaluator.add(jdict)
    evaluator.save()
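
    # Everything below (local TP/FP matching, mAP computation, pycocotools
    # evaluation and the return values) is left commented out for sotabench,
    # where the metrics are computed from the detections submitted through the
    # evaluator above.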

    # # Compute statistics
    # stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
    # if len(stats) and stats[0].any():
    #     p, r, ap, f1, ap_class = ap_per_class(*stats)
    #     p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1)  # [P, R, AP@0.5, AP@0.5:0.95]
    #     mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
    #     nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
    # else:
    #     nt = torch.zeros(1)
    #
    # # Print results
    # pf = '%20s' + '%12.3g' * 6  # print format
    # print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
    #
    # # Print results per class
    # if verbose and nc > 1 and len(stats):
    #     for i, c in enumerate(ap_class):
    #         print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
    #
    # # Print speeds
    # t = tuple(x / seen * 1E3 for x in (t0, t1, t0 + t1)) + (imgsz, imgsz, batch_size)  # tuple
    # if not training:
    #     print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)
    #
    # # Save JSON
    # if save_json and len(jdict):
    #     f = 'detections_val2017_%s_results.json' % \
    #         (weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '')  # filename
    #     print('\nCOCO mAP with pycocotools... saving %s...' % f)
    #     with open(f, 'w') as file:
    #         json.dump(jdict, file)
    #
    #     try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
    #         from pycocotools.coco import COCO
    #         from pycocotools.cocoeval import COCOeval
    #
    #         imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files]
    #         cocoGt = COCO(glob.glob('../coco/annotations/instances_val*.json')[0])  # initialize COCO ground truth api
    #         cocoDt = cocoGt.loadRes(f)  # initialize COCO pred api
    #         cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
    #         cocoEval.params.imgIds = imgIds  # image IDs to evaluate
    #         cocoEval.evaluate()
    #         cocoEval.accumulate()
    #         cocoEval.summarize()
    #         map, map50 = cocoEval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
    #     except Exception as e:
    #         print('ERROR: pycocotools unable to run: %s' % e)
    #
    # # Return results
    # model.float()  # for training
    # maps = np.zeros(nc) + map
    # for i, c in enumerate(ap_class):
    #     maps[c] = ap[i]
    # return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t


if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog='test.py')
    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
    parser.add_argument('--data', type=str, default='data/coco.yaml', help='*.data path')
    parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.65, help='IOU threshold for NMS')
    parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file')
    parser.add_argument('--task', default='val', help="'val', 'test', 'study'")
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--merge', action='store_true', help='use Merge NMS')
    parser.add_argument('--verbose', action='store_true', help='report mAP by class')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    opt = parser.parse_args()
    opt.save_json |= opt.data.endswith('coco.yaml')
    opt.data = check_file(opt.data)  # check file
    print(opt)

    if opt.task in ['val', 'test']:  # run normally
        test(opt.data,
             opt.weights,
             opt.batch_size,
             opt.img_size,
             opt.conf_thres,
             opt.iou_thres,
             opt.save_json,
             opt.single_cls,
             opt.augment,
             opt.verbose)

    elif opt.task == 'study':  # run over a range of settings and save/plot
        # NOTE: test() currently returns None in this sotabench variant (its
        # return statement is commented out above), so the unpacking below will
        # fail until that return is restored.
        for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
            f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem)  # filename to save to
            x = list(range(320, 800, 64))  # x axis
            y = []  # y axis
            for i in x:  # img-size
                print('\nRunning %s point %s...' % (f, i))
                r, _, t = test(opt.data, weights, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json)
                y.append(r + t)  # results and times
            np.savetxt(f, y, fmt='%10.4g')  # save
        os.system('zip -r study.zip study_*.txt')
        # utils.general.plot_study_txt(f, x)  # plot
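
# Example local run (script name assumed from the argparse prog above; expects a
# ../coco checkout and downloaded weights):
#   $ python test.py --data data/coco.yaml --weights yolov5s.pt --img-size 640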