You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

crowdGather.py 6.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import sys
  2. from pathlib import Path
  3. import math
  4. import cv2
  5. import numpy as np
  6. import torch
  7. import math
  8. import time
  9. FILE = Path(__file__).absolute()
  10. #sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
  11. def calculate_distance(point1, point2):
  12. """计算两个点之间的欧氏距离"""
  13. point= center_coordinate(point1)
  14. point=np.array(point)
  15. other_point = center_coordinate(point2)
  16. other_point = np.array(other_point)
  17. return np.linalg.norm(point - other_point)
  18. def find_clusters(preds, min_distance):
  19. """按照最小距离将点分成簇"""
  20. points=preds
  21. points=np.array(points)
  22. clusters = []
  23. used_points = set()
  24. for i, point in enumerate(points):
  25. if i not in used_points: # 如果该点未被使用过
  26. cluster = [point]
  27. used_points.add(i)
  28. # 寻找与该点距离小于等于min_distance的其他点
  29. for j, other_point in enumerate(points):
  30. if j not in used_points:
  31. if all(calculate_distance(point, other_point) <= min_distance
  32. for point in cluster):
  33. cluster.append(other_point)
  34. used_points.add(j)
  35. clusters.append(cluster)
  36. return clusters
  37. def center_coordinate(boundbxs):
  38. '''
  39. 根据检测矩形框,得到其矩形长度和宽度
  40. 输入:两个对角坐标xyxy
  41. 输出:矩形框重点坐标xy
  42. '''
  43. boundbxs_x1 = boundbxs[0]
  44. boundbxs_y1 = boundbxs[1]
  45. boundbxs_x2 = boundbxs[2]
  46. boundbxs_y2 = boundbxs[3]
  47. center_x = 0.5 * (boundbxs_x1 + boundbxs_x2)
  48. center_y = 0.5 * (boundbxs_y1 + boundbxs_y2)
  49. return [center_x, center_y]
  50. def get_bounding_rectangle(rectangles):
  51. '''
  52. 通过输入多个矩形的对角坐标,得到这几个矩形的外包矩形对角坐标
  53. 输入:点簇列表 (嵌套列表)
  54. 输出:多个矩形的外包矩形对角坐标 (列表)
  55. '''
  56. min_x, max_x, min_y, max_y = float('inf'), float('-inf'), float('inf'), float('-inf')
  57. for rect in rectangles:
  58. x1, y1, x2, y2,c1,t1 = rect
  59. min_x = min(min_x, min(x1, x2))
  60. max_x = max(max_x, max(x1, x2))
  61. min_y = min(min_y, min(y1, y2))
  62. max_y = max(max_y, max(y1, y2))
  63. return [min_x, min_y, max_x, max_y]
  64. def calculate_score(input_value):
  65. '''
  66. 计算人群聚集置信度,检测出3-10人内,按照0.85-1的上升趋势取值;
  67. 当检测超过10人,直接判断分数为1.
  68. '''
  69. if input_value == 3:
  70. output_value=0.85
  71. elif input_value == 4:
  72. output_value=0.9
  73. elif 5<= input_value <=10:
  74. output_value = 0.9+(input_value-4)*0.015
  75. else:
  76. output_value=1
  77. return output_value
  78. def gather_post_process(predsList, pars):
  79. '''
  80. 后处理程序,针对检测出的pedestrian,进行人员聚集的算法检测,按照类别'crowd_people'增加predsList
  81. ①原类别:
  82. ['ForestSpot', 'PestTree', 'pedestrian', 'fire', 'smog','cloud']=[0,1,2,3,4,5]
  83. ②处理后的类别汇总:
  84. ['ForestSpot', 'PestTree', 'pedestrian', 'fire', 'smog','cloud','crowd_people']=[0,1,2,3,4,5,6]
  85. 输入:
  86. preds 一张图像的检测结果,为嵌套列表,tensor,包括x_y_x_y_conf_class
  87. imgwidth,imgheight 图像的原始宽度及长度
  88. 输出:检测结果(将其中未悬挂国旗的显示)
  89. '''
  90. t0=time.time()
  91. predsList = predsList[0]
  92. predsList = [x for x in predsList if int(x[5]) !=5 ]##把类别“云朵”去除
  93. # 1、过滤掉类别2以外的目标,只保留行人
  94. preds = [ x for x in predsList if int(x[5]) ==pars['pedestrianId'] ]
  95. if len(preds)< pars['crowdThreshold']:
  96. return predsList,'gaher postTime:No gathering'
  97. preds = np.array(preds)
  98. longs = np.mean(np.max(preds[:,2:4]-preds[:,0:2]))
  99. distanceThreshold = pars['distancePersonScale']*longs
  100. # 2、查找点簇
  101. clusters = find_clusters(preds, distanceThreshold)
  102. clusters_crowd = []
  103. # 3、输出点簇信息,点簇中数量超过阈值,判断人员聚集
  104. for i, cluster in enumerate(clusters):
  105. if len(cluster) >= pars['crowdThreshold']: # 超过一定人数,即为人员聚集
  106. #print(f"Cluster {i + 1}: {len(cluster)} points")
  107. clusters_crowd.append(cluster)
  108. #print(clusters_crowd)
  109. # 4、根据得到的人员聚集点簇,合并其他类别检测结果
  110. for i in range(len(clusters_crowd)):
  111. xyxy = get_bounding_rectangle(clusters_crowd[i]) # 人群聚集包围框
  112. score = calculate_score(len(clusters_crowd[i])) # 人群聚集置信度
  113. xyxy.append(score) # 人群聚集置信度
  114. xyxy.append(pars['gatherId']) # 人群聚集类别
  115. predsList.append(xyxy)
  116. # 5、输出最终类别,共7类,用于绘图显示
  117. output_predslist = predsList
  118. #print('craoGaher line131:',output_predslist)
  119. t1=time.time()
  120. return output_predslist,'gaher postTime:%.1f ms'%( (t1-t0)*1000 )
  121. if __name__ == "__main__":
  122. t1 = time.time()
  123. # 对应vendor1_20240529_99.jpg检测结果
  124. preds=[[224.19933, 148.30751, 278.19156, 199.87828, 0.87625, 2.00000],
  125. [362.67139, 161.25760, 417.72357, 211.51706, 0.86919, 2.00000],
  126. [437.00131, 256.19083, 487.88870, 307.72897, 0.85786, 2.00000],
  127. [442.64606, 335.78168, 493.75720, 371.41418, 0.85245, 2.00000],
  128. [324.58362, 256.18488, 357.72626, 294.08929, 0.84512, 2.00000],
  129. [343.59781, 301.06506, 371.04105, 350.01086, 0.84207, 2.00000],
  130. [301.35858, 210.64088, 332.64862, 250.78883, 0.84063, 2.00000],
  131. [406.02994, 216.91214, 439.44455, 249.26077, 0.83698, 2.00000],
  132. [321.53494, 99.68467, 354.67477, 135.53226, 0.82515, 2.00000],
  133. [253.97131, 202.65234, 302.06055, 233.30634, 0.81498, 2.00000],
  134. [365.62521, 66.42108, 442.02292, 127.37558, 0.79556, 1.00000]]
  135. #preds=torch.tensor(preds) #返回的预测结果
  136. imgwidth=1920
  137. imgheight=1680
  138. pars={'imgSize':(imgwidth,imgheight),'pedestrianId':2,'crowdThreshold':4,'gatherId':6,'distancePersonScale':2.0}
  139. '''
  140. pedestrianId 为行人识别的类别;
  141. crowdThreshold为设置的判断人员聚集的人数阈值,默认4人为聚集
  142. distanceThreshold为设置的判断人员聚集的距离阈值,为了测试默认300像素内为聚集(可自行设置)
  143. '''
  144. yyy=gather_post_process(preds,pars) #送入后处理函数
  145. t2 = time.time()
  146. ttt = t2 - t1
  147. print('时间', ttt * 1000)