无人机视角的行人小目标检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

77 lines
2.3KB

  1. import numpy as np
  2. import pandas as pd
  3. import os
  4. from PIL import Image
  5. from tqdm import tqdm
  6. from ensemble_boxes import *
  7. def xywh2x1y1x2y2(bbox):
  8. x1 = bbox[0] - bbox[2]/2
  9. x2 = bbox[0] + bbox[2]/2
  10. y1 = bbox[1] - bbox[3]/2
  11. y2 = bbox[1] + bbox[3]/2
  12. return ([x1,y1,x2,y2])
  13. def x1y1x2y22xywh(bbox):
  14. x = (bbox[0] + bbox[2])/2
  15. y = (bbox[1] + bbox[3])/2
  16. w = bbox[2] - bbox[0]
  17. h = bbox[3] - bbox[1]
  18. return ([x,y,w,h])
  19. IMG_PATH = '/VisDrone2019-DET-test-challenge/images/'
  20. TXT_PATH = './runs/val/'
  21. OUT_PATH = './runs/wbf_labels/'
  22. MODEL_NAME = os.listdir(TXT_PATH)
  23. # MODEL_NAME = ['test1','test2']
  24. # ===============================
  25. # Default WBF config (you can change these)
  26. iou_thr = 0.67 #0.67
  27. skip_box_thr = 0.01
  28. # skip_box_thr = 0.0001
  29. sigma = 0.1
  30. # boxes_list, scores_list, labels_list, weights=weights,
  31. # ===============================
  32. image_ids = os.listdir(IMG_PATH)
  33. for image_id in tqdm(image_ids, total=len(image_ids)):
  34. boxes_list = []
  35. scores_list = []
  36. labels_list = []
  37. weights = []
  38. for name in MODEL_NAME:
  39. box_list = []
  40. score_list = []
  41. label_list = []
  42. txt_file = TXT_PATH + name + '/labels/' + image_id.replace('jpg', 'txt')
  43. if os.path.exists(txt_file):
  44. # if os.path.getsize(txt_file) > 0:
  45. txt_df = pd.read_csv(txt_file,header=None,sep=' ').values
  46. for row in txt_df:
  47. box_list.append(xywh2x1y1x2y2(row[1:5]))
  48. score_list.append(row[5])
  49. label_list.append(int(row[0]))
  50. boxes_list.append(box_list)
  51. scores_list.append(score_list)
  52. labels_list.append(label_list)
  53. weights.append(1.0)
  54. else:
  55. continue
  56. # print(txt_file)
  57. boxes, scores, labels = weighted_boxes_fusion(boxes_list, scores_list, labels_list, weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
  58. if not os.path.exists(OUT_PATH):
  59. os.makedirs(OUT_PATH)
  60. out_file = open(OUT_PATH + image_id.replace('jpg', 'txt'), 'w')
  61. for i,row in enumerate(boxes):
  62. img = Image.open(IMG_PATH + image_id)
  63. img_size = img.size
  64. bbox = x1y1x2y22xywh(row)
  65. out_file.write(str(int(labels[i]+1)) + ' ' +" ".join(str(x) for x in bbox) + " " + str(round(scores[i],6)) + '\n')
  66. out_file.close()