无人机视角的行人小目标检测
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

77 lignes
2.3KB

  1. import numpy as np
  2. import pandas as pd
  3. import os
  4. from PIL import Image
  5. from tqdm import tqdm
  6. from ensemble_boxes import *
  7. def xywh2x1y1x2y2(bbox):
  8. x1 = bbox[0] - bbox[2]/2
  9. x2 = bbox[0] + bbox[2]/2
  10. y1 = bbox[1] - bbox[3]/2
  11. y2 = bbox[1] + bbox[3]/2
  12. return ([x1,y1,x2,y2])
  13. def x1y1x2y22xywh(bbox):
  14. x = (bbox[0] + bbox[2])/2
  15. y = (bbox[1] + bbox[3])/2
  16. w = bbox[2] - bbox[0]
  17. h = bbox[3] - bbox[1]
  18. return ([x,y,w,h])
  19. IMG_PATH = '/VisDrone2019-DET-test-challenge/images/'
  20. TXT_PATH = './runs/val/'
  21. OUT_PATH = './runs/wbf_labels/'
  22. MODEL_NAME = os.listdir(TXT_PATH)
  23. # MODEL_NAME = ['test1','test2']
  24. # ===============================
  25. # Default WBF config (you can change these)
  26. iou_thr = 0.67 #0.67
  27. skip_box_thr = 0.01
  28. # skip_box_thr = 0.0001
  29. sigma = 0.1
  30. # boxes_list, scores_list, labels_list, weights=weights,
  31. # ===============================
  32. image_ids = os.listdir(IMG_PATH)
  33. for image_id in tqdm(image_ids, total=len(image_ids)):
  34. boxes_list = []
  35. scores_list = []
  36. labels_list = []
  37. weights = []
  38. for name in MODEL_NAME:
  39. box_list = []
  40. score_list = []
  41. label_list = []
  42. txt_file = TXT_PATH + name + '/labels/' + image_id.replace('jpg', 'txt')
  43. if os.path.exists(txt_file):
  44. # if os.path.getsize(txt_file) > 0:
  45. txt_df = pd.read_csv(txt_file,header=None,sep=' ').values
  46. for row in txt_df:
  47. box_list.append(xywh2x1y1x2y2(row[1:5]))
  48. score_list.append(row[5])
  49. label_list.append(int(row[0]))
  50. boxes_list.append(box_list)
  51. scores_list.append(score_list)
  52. labels_list.append(label_list)
  53. weights.append(1.0)
  54. else:
  55. continue
  56. # print(txt_file)
  57. boxes, scores, labels = weighted_boxes_fusion(boxes_list, scores_list, labels_list, weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
  58. if not os.path.exists(OUT_PATH):
  59. os.makedirs(OUT_PATH)
  60. out_file = open(OUT_PATH + image_id.replace('jpg', 'txt'), 'w')
  61. for i,row in enumerate(boxes):
  62. img = Image.open(IMG_PATH + image_id)
  63. img_size = img.size
  64. bbox = x1y1x2y22xywh(row)
  65. out_file.write(str(int(labels[i]+1)) + ' ' +" ".join(str(x) for x in bbox) + " " + str(round(scores[i],6)) + '\n')
  66. out_file.close()