75 lines
3.5 KiB
Python
75 lines
3.5 KiB
Python
import cv2
|
||
import numpy as np
|
||
import torch
|
||
from loguru import logger
|
||
|
||
def security_post_process(preds, pars):
|
||
# pars={'solar':0}
|
||
'''
|
||
将光伏板上覆盖物、裂缝识别出来
|
||
'''
|
||
# print(preds[0])
|
||
#logger.info('\n分类结果返回:%s'%preds)
|
||
if not preds[0]:
|
||
return [[],'']
|
||
preds = np.array(preds[0])
|
||
preds[:,5] = np.round(preds[:,5]).astype(int)
|
||
person = preds[preds[:, 5] == pars['objs'][0]]
|
||
helmet = preds[preds[:, 5] == pars['objs'][1]]
|
||
other = preds[~np.isin(preds[:,5],pars['objs'])]
|
||
# 重塑为(b1数量, 1, 4)和(1, b2数量, 4)以支持广播
|
||
person_coords = person[:, :4].reshape(1, -1, 4) # 形状: (3, 1, 4)
|
||
helmet_coords = helmet[:, :4].reshape(-1, 1, 4) # 形状: (1, 2, 4)
|
||
|
||
# 批量计算交集坐标
|
||
inter_x1 = np.maximum(person_coords[..., 0], helmet_coords[..., 0]) # 形状: (3, 2)
|
||
inter_y1 = np.maximum(person_coords[..., 1], helmet_coords[..., 1])
|
||
inter_x2 = np.minimum(person_coords[..., 2], helmet_coords[..., 2])
|
||
inter_y2 = np.minimum(person_coords[..., 3], helmet_coords[..., 3])
|
||
|
||
# 计算交集面积(确保非负)
|
||
inter_area = np.maximum(0, inter_x2 - inter_x1) * np.maximum(0, inter_y2 - inter_y1) # 形状: (3, 2)
|
||
|
||
# 计算b1和b2各自的面积
|
||
person_area = (person_coords[..., 2] - person_coords[..., 0]) * (person_coords[..., 3] - person_coords[..., 1]) # 形状: (3, 1)
|
||
helmet_area = (helmet_coords[..., 2] - helmet_coords[..., 0]) * (helmet_coords[..., 3] - helmet_coords[..., 1]) # 形状: (1, 2)
|
||
|
||
# 计算并集面积和交并比(IoU)
|
||
union_area = person_area + helmet_area - inter_area + 0.001 # 形状: (3, 2)
|
||
iou = inter_area / union_area # 形状: (3, 2),每行对应b1元素,每列对应b2元素
|
||
|
||
# 找到b2中与任何b1元素IoU>0.25的索引
|
||
# 对每行(b1元素)取最大值,再判断是否有任何b1元素满足条件
|
||
person_mask = np.any(iou > pars['iou'], axis=0) # 形状: (2,),True表示符合条件
|
||
# logger.info('异常person_mask:',person_mask)
|
||
# logger.info('异常person:', person)
|
||
inhelmet = person[person_mask]
|
||
unhelmet = person[~person_mask]
|
||
unhelmet[:,5] = pars['unhelmet']
|
||
preds = np.row_stack((inhelmet,helmet,unhelmet,other))
|
||
|
||
|
||
return [preds.tolist(),'']
|
||
|
||
if __name__ == "__main__":
|
||
# 对应DJI_20230306140129_0001_Z_165.jpg检测结果
|
||
# preds=[[6.49000e+02, 2.91000e+02, 1.07900e+03, 7.33000e+02, 9.08165e-01, 3.00000e+00],
|
||
# [8.11000e+02, 2.99000e+02, 1.31200e+03, 7.65000e+02, 8.61268e-01, 3.00000e+00],
|
||
# [7.05000e+02, 1.96000e+02, 7.19000e+02, 2.62000e+02, 5.66877e-01, 0.00000e+00]]
|
||
|
||
# 对应DJI_20230306152702_0001_Z_562.jpg检测结果
|
||
preds = [[7.62000e+02, 7.14000e+02, 1.82800e+03, 9.51000e+02, 9.00902e-01, 3.00000e+00],
|
||
[2.00000e+01, 3.45000e+02, 1.51300e+03, 6.71000e+02, 8.81440e-01, 3.00000e+00],
|
||
[8.35000e+02, 8.16000e+02, 8.53000e+02, 8.30000e+02, 7.07651e-01, 0.00000e+00],
|
||
[1.35600e+03, 4.56000e+02, 1.42800e+03, 4.94000e+02, 6.70549e-01, 2.00000e+00]]
|
||
print('before :\n ', preds)
|
||
# preds=torch.tensor(preds) #返回的预测结果
|
||
imgwidth = 1920
|
||
imgheight = 1680
|
||
pars = {'imgSize': (imgwidth, imgheight), 'wRation': 1 / 6.0, 'hRation': 1 / 6.0, 'smallId': 0, 'bigId': 3,
|
||
'newId': 4, 'recScale': 1.2}
|
||
# 'smallId':0(国旗),'bigId':3(船只),wRation和hRation表示判断的阈值条件,newId--新目标的id
|
||
# yyy = channel2_post_process([preds], pars) # 送入后处理函数
|
||
#
|
||
# print('after :\n ', yyy)
|