You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

8 ay önce
  1. import numpy as np
  2. import pandas as pd
  3. import os
  4. from PIL import Image
  5. from tqdm import tqdm
  6. from ensemble_boxes import *
  7. def xywh2x1y1x2y2(bbox):
  8. x1 = bbox[0] - bbox[2]/2
  9. x2 = bbox[0] + bbox[2]/2
  10. y1 = bbox[1] - bbox[3]/2
  11. y2 = bbox[1] + bbox[3]/2
  12. return ([x1,y1,x2,y2])
  13. def x1y1x2y22xywh(bbox):
  14. x = (bbox[0] + bbox[2])/2
  15. y = (bbox[1] + bbox[3])/2
  16. w = bbox[2] - bbox[0]
  17. h = bbox[3] - bbox[1]
  18. return ([x,y,w,h])
  19. IMG_PATH = '/VisDrone2019-DET-test-challenge/images/'
  20. TXT_PATH = './runs/val/'
  21. OUT_PATH = './runs/wbf_labels/'
  22. MODEL_NAME = os.listdir(TXT_PATH)
  23. # MODEL_NAME = ['test1','test2']
  24. # ===============================
  25. # Default WBF config (you can change these)
  26. iou_thr = 0.67 #0.67
  27. skip_box_thr = 0.01
  28. # skip_box_thr = 0.0001
  29. sigma = 0.1
  30. # boxes_list, scores_list, labels_list, weights=weights,
  31. # ===============================
  32. image_ids = os.listdir(IMG_PATH)
  33. for image_id in tqdm(image_ids, total=len(image_ids)):
  34. boxes_list = []
  35. scores_list = []
  36. labels_list = []
  37. weights = []
  38. for name in MODEL_NAME:
  39. box_list = []
  40. score_list = []
  41. label_list = []
  42. txt_file = TXT_PATH + name + '/labels/' + image_id.replace('jpg', 'txt')
  43. if os.path.exists(txt_file):
  44. # if os.path.getsize(txt_file) > 0:
  45. txt_df = pd.read_csv(txt_file,header=None,sep=' ').values
  46. for row in txt_df:
  47. box_list.append(xywh2x1y1x2y2(row[1:5]))
  48. score_list.append(row[5])
  49. label_list.append(int(row[0]))
  50. boxes_list.append(box_list)
  51. scores_list.append(score_list)
  52. labels_list.append(label_list)
  53. weights.append(1.0)
  54. else:
  55. continue
  56. # print(txt_file)
  57. boxes, scores, labels = weighted_boxes_fusion(boxes_list, scores_list, labels_list, weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
  58. if not os.path.exists(OUT_PATH):
  59. os.makedirs(OUT_PATH)
  60. out_file = open(OUT_PATH + image_id.replace('jpg', 'txt'), 'w')
  61. for i,row in enumerate(boxes):
  62. img = Image.open(IMG_PATH + image_id)
  63. img_size = img.size
  64. bbox = x1y1x2y22xywh(row)
  65. out_file.write(str(int(labels[i]+1)) + ' ' +" ".join(str(x) for x in bbox) + " " + str(round(scores[i],6)) + '\n')
  66. out_file.close()