import argparse
import glob
import json
import os
import shutil
from pathlib import Path

import numpy as np
import torch
import yaml
from tqdm import tqdm

from models.experimental import attempt_load
from utils.datasets import create_dataloader
from utils.general import (
    coco80_to_coco91_class, check_dataset, check_file, check_img_size, compute_loss, non_max_suppression,
    scale_coords, xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, ap_per_class,
    set_logging)
from utils.torch_utils import select_device, time_synchronized
def test(data,
         weights=None,
         batch_size=16,
         imgsz=640,
         conf_thres=0.001,
         iou_thres=0.6,  # for NMS
         save_json=False,
         single_cls=False,
         augment=False,
         verbose=False,
         model=None,
         dataloader=None,
         save_dir=Path(''),  # for saving images
         save_txt=False,  # for auto-labelling
         save_conf=False,
         plots=True,
         log_imgs=0):  # number of logged images

    # Initialize/load model and set device
    training = model is not None
    if training:  # called by train.py
        device = next(model.parameters()).device  # get model device

    else:  # called directly
        set_logging()
        device = select_device(opt.device, batch_size=batch_size)
        save_txt = opt.save_txt  # save *.txt labels

        # Remove previous
        if os.path.exists(save_dir):
            shutil.rmtree(save_dir)  # delete dir
        os.makedirs(save_dir)  # make new dir

        if save_txt:
            out = save_dir / 'autolabels'
            if os.path.exists(out):
                shutil.rmtree(out)  # delete dir
            os.makedirs(out)  # make new dir

        # Load model
        model = attempt_load(weights, map_location=device)  # load FP32 model
        imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size

        # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
        # if device.type != 'cpu' and torch.cuda.device_count() > 1:
        #     model = nn.DataParallel(model)

    # Half
    half = device.type != 'cpu'  # half precision only supported on CUDA
    if half:
        model.half()

    # Configure
    model.eval()
    with open(data) as f:
        data = yaml.load(f, Loader=yaml.FullLoader)  # model dict
    check_dataset(data)  # check
    nc = 1 if single_cls else int(data['nc'])  # number of classes
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()
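    # Note: iouv is the 10-element vector [0.50, 0.55, ..., 0.95]. Each prediction is later
    # scored against every threshold at once, producing a boolean 'correct' matrix of shape
    # (num_predictions, niou), from which COCO-style mAP@0.5:0.95 is computed.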
    # Logging
    log_imgs = min(log_imgs, 100)  # cap at 100 logged images
    try:
        import wandb  # Weights & Biases
    except ImportError:
        log_imgs = 0

    # Dataloader
    if not training:
        img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
        _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
        path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
        dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt,
                                       hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]

    seen = 0
    names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
    coco91class = coco80_to_coco91_class()
    s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
    p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
    loss = torch.zeros(3, device=device)
    jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []
    for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
        img = img.to(device, non_blocking=True)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        targets = targets.to(device)
        nb, _, height, width = img.shape  # batch size, channels, height, width
        whwh = torch.Tensor([width, height, width, height]).to(device)
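        # Dataset labels arrive as normalized xywh in [0, 1]; multiplying by whwh further down
        # converts target boxes to pixel coordinates so they can be compared against predictions.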
        # Disable gradients
        with torch.no_grad():
            # Run model
            t = time_synchronized()
            inf_out, train_out = model(img, augment=augment)  # inference and training outputs
            t0 += time_synchronized() - t

            # Compute loss
            if training:  # if model has loss hyperparameters
                loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]  # box, obj, cls

            # Run NMS
            t = time_synchronized()
            output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
            t1 += time_synchronized() - t

        # Statistics per image
        for si, pred in enumerate(output):
            labels = targets[targets[:, 0] == si, 1:]
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            seen += 1

            if pred is None:
                if nl:
                    stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
                continue

            # Append to text file
            if save_txt:
                gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]]  # normalization gain whwh
                x = pred.clone()
                x[:, :4] = scale_coords(img[si].shape[1:], x[:, :4], shapes[si][0], shapes[si][1])  # to original
                for *xyxy, conf, cls in x:
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    line = (cls, conf, *xywh) if save_conf else (cls, *xywh)  # label format
                    with open(str(out / Path(paths[si]).stem) + '.txt', 'a') as f:
                        f.write(('%g ' * len(line) + '\n') % line)

            # W&B logging
            if len(wandb_images) < log_imgs:
                box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
                             "class_id": int(cls),
                             "box_caption": "%s %.3f" % (names[cls], conf),
                             "scores": {"class_score": conf},
                             "domain": "pixel"} for *xyxy, conf, cls in pred.clone().tolist()]
                boxes = {"predictions": {"box_data": box_data, "class_labels": names}}
                wandb_images.append(wandb.Image(img[si], boxes=boxes))

            # Clip boxes to image bounds
            clip_coords(pred, (height, width))

            # Append to pycocotools JSON dictionary
            if save_json:
                # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
                image_id = Path(paths[si]).stem
                box = pred[:, :4].clone()  # xyxy
                scale_coords(img[si].shape[1:], box, shapes[si][0], shapes[si][1])  # to original shape
                box = xyxy2xywh(box)  # xywh
                box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
                for p, b in zip(pred.tolist(), box.tolist()):
                    jdict.append({'image_id': int(image_id) if image_id.isnumeric() else image_id,
                                  'category_id': coco91class[int(p[5])],
                                  'bbox': [round(x, 3) for x in b],
                                  'score': round(p[4], 5)})
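                # COCO expects 'bbox' as [x_min, y_min, width, height] in original image pixels,
                # and coco91class maps the model's contiguous 80-class index to COCO's sparse
                # 91-category IDs; both conversions are handled above.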
            # Assign all predictions as incorrect
            correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
            if nl:
                detected = []  # target indices
                tcls_tensor = labels[:, 0]

                # target boxes
                tbox = xywh2xyxy(labels[:, 1:5]) * whwh

                # Per target class
                for cls in torch.unique(tcls_tensor):
                    ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1)  # target indices
                    pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1)  # prediction indices

                    # Search for detections
                    if pi.shape[0]:
                        # Prediction to target ious
                        ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1)  # best ious, indices

                        # Append detections
                        detected_set = set()
                        for j in (ious > iouv[0]).nonzero(as_tuple=False):
                            d = ti[i[j]]  # detected target
                            if d.item() not in detected_set:
                                detected_set.add(d.item())
                                detected.append(d)
                                correct[pi[j]] = ious[j] > iouv  # compare against all niou thresholds at once
                                if len(detected) == nl:  # all targets already located in image
                                    break

            # Append statistics (correct, conf, pcls, tcls)
            stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
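            # The matching above is greedy per class: each target may be claimed by at most one
            # prediction (tracked via detected_set), and a prediction counts as correct at a given
            # IoU threshold only if its best-matching target IoU exceeds that threshold.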
        # Plot images
        if plots and batch_i < 1:
            f = save_dir / f'test_batch{batch_i}_gt.jpg'  # filename
            plot_images(img, targets, paths, str(f), names)  # ground truth
            f = save_dir / f'test_batch{batch_i}_pred.jpg'
            plot_images(img, output_to_target(output, width, height), paths, str(f), names)  # predictions

    # W&B logging
    if wandb_images:
        wandb.log({"outputs": wandb_images})

    # Compute statistics
    stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
    if len(stats) and stats[0].any():
        p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, fname=save_dir / 'precision-recall_curve.png')
        p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1)  # [P, R, AP@0.5, AP@0.5:0.95]
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
    else:
        nt = torch.zeros(1)

    # Print results
    pf = '%20s' + '%12.3g' * 6  # print format
    print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))

    # Print results per class
    if verbose and nc > 1 and len(stats):
        for i, c in enumerate(ap_class):
            print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

    # Print speeds
    t = tuple(x / seen * 1E3 for x in (t0, t1, t0 + t1)) + (imgsz, imgsz, batch_size)  # tuple
    if not training:
        print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)

    # Save JSON
    if save_json and len(jdict):
        w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else ''  # weights
        file = save_dir / f"detections_val2017_{w}_results.json"  # predicted annotations file
        print('\nCOCO mAP with pycocotools... saving %s...' % file)
        with open(file, 'w') as f:
            json.dump(jdict, f)

        try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
            from pycocotools.coco import COCO
            from pycocotools.cocoeval import COCOeval

            imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files]
            cocoGt = COCO(glob.glob('../coco/annotations/instances_val*.json')[0])  # initialize COCO ground truth api
            cocoDt = cocoGt.loadRes(str(file))  # initialize COCO pred api
            cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
            cocoEval.params.imgIds = imgIds  # image IDs to evaluate
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()
            map, map50 = cocoEval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
        except Exception as e:
            print('ERROR: pycocotools unable to run: %s' % e)

    # Return results
    model.float()  # for training
    maps = np.zeros(nc) + map
    for i, c in enumerate(ap_class):
        maps[c] = ap[i]
    return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t
if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog='test.py')
    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
    parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.65, help='IOU threshold for NMS')
    parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file')
    parser.add_argument('--task', default='val', help="'val', 'test', 'study'")
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--verbose', action='store_true', help='report mAP by class')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-dir', type=str, default='runs/test', help='directory to save results')
    opt = parser.parse_args()
    opt.save_json |= opt.data.endswith('coco.yaml')
    opt.data = check_file(opt.data)  # check file
    print(opt)

    if opt.task in ['val', 'test']:  # run normally
        test(opt.data,
             opt.weights,
             opt.batch_size,
             opt.img_size,
             opt.conf_thres,
             opt.iou_thres,
             opt.save_json,
             opt.single_cls,
             opt.augment,
             opt.verbose,
             save_dir=Path(opt.save_dir),
             save_txt=opt.save_txt,
             save_conf=opt.save_conf,
             )

        print('Results saved to %s' % opt.save_dir)

    elif opt.task == 'study':  # run over a range of settings and save/plot
        for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
            f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem)  # filename to save to
            x = list(range(320, 800, 64))  # x axis
            y = []  # y axis
            for i in x:  # img-size
                print('\nRunning %s point %s...' % (f, i))
                r, _, t = test(opt.data, weights, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json)
                y.append(r + t)  # results and times
            np.savetxt(f, y, fmt='%10.4g')  # save
        os.system('zip -r study.zip study_*.txt')
        # utils.general.plot_study_txt(f, x)  # plot
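
# Example invocations (a sketch, assuming a trained checkpoint and a dataset YAML are on hand;
# flags correspond to the argparse definitions above):
#   python test.py --data data/coco128.yaml --weights yolov5s.pt --img-size 640
#   python test.py --data coco.yaml --weights yolov5s.pt --task test --save-json --verbose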