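"""Evaluate a trained YOLOv5 model on a validation or test set.

Computes per-class precision, recall, mAP@0.5 and mAP@0.5:0.95, and can
optionally export detections as *.txt label files or as a pycocotools-compatible
JSON file for official COCO evaluation.
"""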

import argparse
import glob
import json
import os
import shutil
from pathlib import Path

import numpy as np
import torch
import yaml
from tqdm import tqdm

from models.experimental import attempt_load
from utils.datasets import create_dataloader
from utils.general import (
    coco80_to_coco91_class, check_dataset, check_file, check_img_size, compute_loss, non_max_suppression, scale_coords,
    xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, ap_per_class, set_logging)
from utils.torch_utils import select_device, time_synchronized


def test(data,
         weights=None,
         batch_size=16,
         imgsz=640,
         conf_thres=0.001,
         iou_thres=0.6,  # for NMS
         save_json=False,
         single_cls=False,
         augment=False,
         verbose=False,
         model=None,
         dataloader=None,
         save_dir='',
         merge=False,
         save_txt=False):
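    """Run evaluation; called directly from the CLI or by train.py with a live model.

    Returns (mP, mR, mAP@0.5, mAP@0.5:0.95, val losses), per-class mAPs, and
    inference/NMS timing.
    """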
    # Initialize/load model and set device
    training = model is not None
    if training:  # called by train.py
        device = next(model.parameters()).device  # get model device
    else:  # called directly
        set_logging()
        device = select_device(opt.device, batch_size=batch_size)
        merge, save_txt = opt.merge, opt.save_txt  # use Merge NMS, save *.txt labels
        if save_txt:
            out = Path('inference/output')
            if os.path.exists(out):
                shutil.rmtree(out)  # delete output folder
            os.makedirs(out)  # make new output folder

        # Remove previous test_batch*.jpg plots
        for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')):
            os.remove(f)

        # Load model
        model = attempt_load(weights, map_location=device)  # load FP32 model
        imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size

        # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
        # if device.type != 'cpu' and torch.cuda.device_count() > 1:
        #     model = nn.DataParallel(model)

    # Half
    half = device.type != 'cpu'  # half precision only supported on CUDA
    if half:
        model.half()

    # Configure
    model.eval()
    with open(data) as f:
        data = yaml.load(f, Loader=yaml.FullLoader)  # model dict
    check_dataset(data)  # check
    nc = 1 if single_cls else int(data['nc'])  # number of classes
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()

    # Dataloader
    if not training:
        img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
        _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
        path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
        dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt,
                                       hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]

    seen = 0
    names = model.names if hasattr(model, 'names') else model.module.names
    coco91class = coco80_to_coco91_class()
    s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
    p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
    loss = torch.zeros(3, device=device)
    jdict, stats, ap, ap_class = [], [], [], []
    for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
        img = img.to(device, non_blocking=True)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        targets = targets.to(device)
        nb, _, height, width = img.shape  # batch size, channels, height, width
        whwh = torch.Tensor([width, height, width, height]).to(device)
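        # Note: each targets row is (image_idx, class, x_center, y_center, w, h)
        # with coordinates normalized to 0-1; whwh rescales them to pixels below.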

        # Disable gradients
        with torch.no_grad():
            # Run model
            t = time_synchronized()
            inf_out, train_out = model(img, augment=augment)  # inference and training outputs
            t0 += time_synchronized() - t

            # Compute loss
            if training:  # if model has loss hyperparameters
                loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]  # GIoU, obj, cls

            # Run NMS
            t = time_synchronized()
            output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
            t1 += time_synchronized() - t
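
        # After NMS, output is a per-image list of (n, 6) tensors with columns
        # (x1, y1, x2, y2, confidence, class), or None if no detections survived.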
        # Statistics per image
        for si, pred in enumerate(output):
            labels = targets[targets[:, 0] == si, 1:]
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            seen += 1

            if pred is None:
                if nl:
                    stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
                continue

            # Append to text file
            if save_txt:
                gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]]  # normalization gain whwh
                txt_path = str(out / Path(paths[si]).stem)
                pred[:, :4] = scale_coords(img[si].shape[1:], pred[:, :4], shapes[si][0], shapes[si][1])  # to original
                for *xyxy, conf, cls in pred:
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    with open(txt_path + '.txt', 'a') as f:
                        f.write(('%g ' * 5 + '\n') % (cls, *xywh))  # label format

            # Clip boxes to image bounds
            clip_coords(pred, (height, width))

            # Append to pycocotools JSON dictionary
            if save_json:
                # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
                image_id = Path(paths[si]).stem
                box = pred[:, :4].clone()  # xyxy
                scale_coords(img[si].shape[1:], box, shapes[si][0], shapes[si][1])  # to original shape
                box = xyxy2xywh(box)  # xywh
                box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
                for p, b in zip(pred.tolist(), box.tolist()):
                    jdict.append({'image_id': int(image_id) if image_id.isnumeric() else image_id,
                                  'category_id': coco91class[int(p[5])],
                                  'bbox': [round(x, 3) for x in b],
                                  'score': round(p[4], 5)})

            # Assign all predictions as incorrect
            correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
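            # Matching: for each class, greedily match predictions to as-yet-unmatched
            # targets by best IoU; a prediction counts as correct at each of the 10
            # thresholds in iouv that its matched IoU exceeds.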
            if nl:
                detected = []  # detected target indices
                tcls_tensor = labels[:, 0]

                # target boxes
                tbox = xywh2xyxy(labels[:, 1:5]) * whwh

                # Per target class
                for cls in torch.unique(tcls_tensor):
                    ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1)  # target indices
                    pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1)  # prediction indices

                    # Search for detections
                    if pi.shape[0]:
                        # Prediction to target ious
                        ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1)  # best ious, indices

                        # Append detections
                        for j in (ious > iouv[0]).nonzero(as_tuple=False):
                            d = ti[i[j]]  # detected target
                            if d not in detected:
                                detected.append(d)
                                correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
                                if len(detected) == nl:  # all targets already located in image
                                    break

            # Append statistics (correct, conf, pcls, tcls)
            stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

        # Plot images
        if batch_i < 1:
            f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i)  # filename
            plot_images(img, targets, paths, str(f), names)  # ground truth
            f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i)
            plot_images(img, output_to_target(output, width, height), paths, str(f), names)  # predictions

    # Compute statistics
    stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
    if len(stats) and stats[0].any():
        p, r, ap, f1, ap_class = ap_per_class(*stats)
        p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1)  # [P, R, AP@0.5, AP@0.5:0.95]
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
    else:
        nt = torch.zeros(1)

    # Print results
    pf = '%20s' + '%12.3g' * 6  # print format
    print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))

    # Print results per class
    if verbose and nc > 1 and len(stats):
        for i, c in enumerate(ap_class):
            print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

    # Print speeds
    t = tuple(x / seen * 1E3 for x in (t0, t1, t0 + t1)) + (imgsz, imgsz, batch_size)  # tuple
    if not training:
        print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)

    # Save JSON
    if save_json and len(jdict):
        f = 'detections_val2017_%s_results.json' % \
            (weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '')  # filename
        print('\nCOCO mAP with pycocotools... saving %s...' % f)
        with open(f, 'w') as file:
            json.dump(jdict, file)

        try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
            from pycocotools.coco import COCO
            from pycocotools.cocoeval import COCOeval

            imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files]
            cocoGt = COCO(glob.glob('../coco/annotations/instances_val*.json')[0])  # initialize COCO ground truth api
            cocoDt = cocoGt.loadRes(f)  # initialize COCO pred api
            cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
            cocoEval.params.imgIds = imgIds  # image IDs to evaluate
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()
            map, map50 = cocoEval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
        except Exception as e:
            print('ERROR: pycocotools unable to run: %s' % e)

    # Return results
    model.float()  # for training
    maps = np.zeros(nc) + map
    for i, c in enumerate(ap_class):
        maps[c] = ap[i]
    return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t


if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog='test.py')
    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
    parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.65, help='IOU threshold for NMS')
    parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file')
    parser.add_argument('--task', default='val', help="'val', 'test', 'study'")
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--merge', action='store_true', help='use Merge NMS')
    parser.add_argument('--verbose', action='store_true', help='report mAP by class')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    opt = parser.parse_args()
    opt.save_json |= opt.data.endswith('coco.yaml')
    opt.data = check_file(opt.data)  # check file
    print(opt)

    if opt.task in ['val', 'test']:  # run normally
        test(opt.data,
             opt.weights,
             opt.batch_size,
             opt.img_size,
             opt.conf_thres,
             opt.iou_thres,
             opt.save_json,
             opt.single_cls,
             opt.augment,
             opt.verbose)

    elif opt.task == 'study':  # run over a range of settings and save/plot
        for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
            f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem)  # filename to save to
            x = list(range(320, 800, 64))  # x axis
            y = []  # y axis
            for i in x:  # img-size
                print('\nRunning %s point %s...' % (f, i))
                r, _, t = test(opt.data, weights, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json)
                y.append(r + t)  # results and times
            np.savetxt(f, y, fmt='%10.4g')  # save
        os.system('zip -r study.zip study_*.txt')
        # utils.general.plot_study_txt(f, x)  # plot
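
# Example usage (weight files and dataset yamls are illustrative; any trained
# YOLOv5 checkpoint and dataset yaml following this repo's conventions should work):
#   python test.py --weights yolov5s.pt --data data/coco128.yaml --img-size 640
#   python test.py --weights yolov5x.pt --data data/coco.yaml --save-json  # official COCO eval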