You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

113 lines
4.4KB

  1. import os,sys
  2. import torch
  3. import numpy as np
  4. sys.path.extend(['../AIlib2/obbUtils'])
  5. #import datasets.DOTA_devkit.ResultMerge_multi_process
  6. #from datasets.DOTA_devkit.ResultMerge_multi_process import py_cpu_nms_poly_fast, py_cpu_nms_poly
  7. from dotadevkit.ops.ResultMerge import py_cpu_nms_poly_fast, py_cpu_nms_poly
  8. import time
  9. # def decode_prediction(predictions, dsets, args, img_id, down_ratio):
  10. def decode_prediction(predictions, category, model_size, down_ratio,ori_image):
  11. t1=time.time()
  12. predictions = predictions[0, :, :]
  13. # ttt1=time.time()
  14. # # ori_image = dsets.load_image(dsets.img_ids.index(img_id)) #加载了原图第2次????这里耗时 改1
  15. # ttt2 = time.time()
  16. # print(f'jiazaitupian. ({(1E3 * (ttt2 - ttt1)):.1f}ms) ')
  17. h, w, c = ori_image.shape
  18. pts0 = {cat: [] for cat in category}
  19. scores0 = {cat: [] for cat in category}
  20. for pred in predictions:
  21. cen_pt = np.asarray([pred[0], pred[1]], np.float32)
  22. tt = np.asarray([pred[2], pred[3]], np.float32)
  23. rr = np.asarray([pred[4], pred[5]], np.float32)
  24. bb = np.asarray([pred[6], pred[7]], np.float32)
  25. ll = np.asarray([pred[8], pred[9]], np.float32)
  26. tl = tt + ll - cen_pt
  27. bl = bb + ll - cen_pt
  28. tr = tt + rr - cen_pt
  29. br = bb + rr - cen_pt
  30. score = pred[10]
  31. clse = pred[11]
  32. pts = np.asarray([tr, br, bl, tl], np.float32)
  33. pts[:, 0] = pts[:, 0] * down_ratio / model_size[0] * w
  34. pts[:, 1] = pts[:, 1] * down_ratio / model_size[1] * h
  35. pts0[category[int(clse)]].append(pts)
  36. scores0[category[int(clse)]].append(score)
  37. t2=time.time()
  38. #print('###line40:decode_prediction time: %.1f ',(t2-t1)*1000.0)
  39. return pts0, scores0
  40. def non_maximum_suppression(pts, scores):
  41. nms_item = np.concatenate([pts[:, 0:1, 0],
  42. pts[:, 0:1, 1],
  43. pts[:, 1:2, 0],
  44. pts[:, 1:2, 1],
  45. pts[:, 2:3, 0],
  46. pts[:, 2:3, 1],
  47. pts[:, 3:4, 0],
  48. pts[:, 3:4, 1],
  49. scores[:, np.newaxis]], axis=1)
  50. nms_item = np.asarray(nms_item, np.float64)
  51. keep_index = py_cpu_nms_poly_fast(dets=nms_item, thresh=0.1)
  52. return nms_item[keep_index]
  53. def write_results(args,
  54. model,
  55. dsets,
  56. down_ratio,
  57. device,
  58. decoder,
  59. result_path,
  60. print_ps=False):
  61. results = {cat: {img_id: [] for img_id in dsets.img_ids} for cat in dsets.category}
  62. for index in range(len(dsets)):
  63. data_dict = dsets.__getitem__(index)
  64. image = data_dict['image'].to(device)
  65. img_id = data_dict['img_id']
  66. image_w = data_dict['image_w']
  67. image_h = data_dict['image_h']
  68. with torch.no_grad():
  69. pr_decs = model(image)
  70. decoded_pts = []
  71. decoded_scores = []
  72. torch.cuda.synchronize(device)
  73. predictions = decoder.ctdet_decode(pr_decs)
  74. pts0, scores0 = decode_prediction(predictions, dsets, args, img_id, down_ratio)
  75. decoded_pts.append(pts0)
  76. decoded_scores.append(scores0)
  77. # nms
  78. for cat in dsets.category:
  79. if cat == 'background':
  80. continue
  81. pts_cat = []
  82. scores_cat = []
  83. for pts0, scores0 in zip(decoded_pts, decoded_scores):
  84. pts_cat.extend(pts0[cat])
  85. scores_cat.extend(scores0[cat])
  86. pts_cat = np.asarray(pts_cat, np.float32)
  87. scores_cat = np.asarray(scores_cat, np.float32)
  88. if pts_cat.shape[0]:
  89. nms_results = non_maximum_suppression(pts_cat, scores_cat)
  90. results[cat][img_id].extend(nms_results)
  91. if print_ps:
  92. print('testing {}/{} data {}'.format(index+1, len(dsets), img_id))
  93. for cat in dsets.category:
  94. if cat == 'background':
  95. continue
  96. with open(os.path.join(result_path, 'Task1_{}.txt'.format(cat)), 'w') as f:
  97. for img_id in results[cat]:
  98. for pt in results[cat][img_id]:
  99. f.write('{} {:.12f} {:.1f} {:.1f} {:.1f} {:.1f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.format(
  100. img_id, pt[8], pt[0], pt[1], pt[2], pt[3], pt[4], pt[5], pt[6], pt[7]))