AIlib2/utilsK/securitypostUtils.py

75 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import torch
from loguru import logger
def security_post_process(preds, pars):
# pars={'solar':0}
'''
将光伏板上覆盖物、裂缝识别出来
'''
# print(preds[0])
#logger.info('\n分类结果返回%s'%preds)
if not preds[0]:
return [[],'']
preds = np.array(preds[0])
preds[:,5] = np.round(preds[:,5]).astype(int)
person = preds[preds[:, 5] == pars['objs'][0]]
helmet = preds[preds[:, 5] == pars['objs'][1]]
other = preds[~np.isin(preds[:,5],pars['objs'])]
# 重塑为(b1数量, 1, 4)和(1, b2数量, 4)以支持广播
person_coords = person[:, :4].reshape(1, -1, 4) # 形状: (3, 1, 4)
helmet_coords = helmet[:, :4].reshape(-1, 1, 4) # 形状: (1, 2, 4)
# 批量计算交集坐标
inter_x1 = np.maximum(person_coords[..., 0], helmet_coords[..., 0]) # 形状: (3, 2)
inter_y1 = np.maximum(person_coords[..., 1], helmet_coords[..., 1])
inter_x2 = np.minimum(person_coords[..., 2], helmet_coords[..., 2])
inter_y2 = np.minimum(person_coords[..., 3], helmet_coords[..., 3])
# 计算交集面积(确保非负)
inter_area = np.maximum(0, inter_x2 - inter_x1) * np.maximum(0, inter_y2 - inter_y1) # 形状: (3, 2)
# 计算b1和b2各自的面积
person_area = (person_coords[..., 2] - person_coords[..., 0]) * (person_coords[..., 3] - person_coords[..., 1]) # 形状: (3, 1)
helmet_area = (helmet_coords[..., 2] - helmet_coords[..., 0]) * (helmet_coords[..., 3] - helmet_coords[..., 1]) # 形状: (1, 2)
# 计算并集面积和交并比(IoU)
union_area = person_area + helmet_area - inter_area + 0.001 # 形状: (3, 2)
iou = inter_area / union_area # 形状: (3, 2)每行对应b1元素每列对应b2元素
# 找到b2中与任何b1元素IoU>0.25的索引
# 对每行b1元素取最大值再判断是否有任何b1元素满足条件
person_mask = np.any(iou > pars['iou'], axis=0) # 形状: (2,)True表示符合条件
# logger.info('异常person_mask:',person_mask)
# logger.info('异常person:', person)
inhelmet = person[person_mask]
unhelmet = person[~person_mask]
unhelmet[:,5] = pars['unhelmet']
preds = np.row_stack((inhelmet,helmet,unhelmet,other))
return [preds.tolist(),'']
if __name__ == "__main__":
# 对应DJI_20230306140129_0001_Z_165.jpg检测结果
# preds=[[6.49000e+02, 2.91000e+02, 1.07900e+03, 7.33000e+02, 9.08165e-01, 3.00000e+00],
# [8.11000e+02, 2.99000e+02, 1.31200e+03, 7.65000e+02, 8.61268e-01, 3.00000e+00],
# [7.05000e+02, 1.96000e+02, 7.19000e+02, 2.62000e+02, 5.66877e-01, 0.00000e+00]]
# 对应DJI_20230306152702_0001_Z_562.jpg检测结果
preds = [[7.62000e+02, 7.14000e+02, 1.82800e+03, 9.51000e+02, 9.00902e-01, 3.00000e+00],
[2.00000e+01, 3.45000e+02, 1.51300e+03, 6.71000e+02, 8.81440e-01, 3.00000e+00],
[8.35000e+02, 8.16000e+02, 8.53000e+02, 8.30000e+02, 7.07651e-01, 0.00000e+00],
[1.35600e+03, 4.56000e+02, 1.42800e+03, 4.94000e+02, 6.70549e-01, 2.00000e+00]]
print('before :\n ', preds)
# preds=torch.tensor(preds) #返回的预测结果
imgwidth = 1920
imgheight = 1680
pars = {'imgSize': (imgwidth, imgheight), 'wRation': 1 / 6.0, 'hRation': 1 / 6.0, 'smallId': 0, 'bigId': 3,
'newId': 4, 'recScale': 1.2}
# 'smallId':0(国旗)'bigId':3(船只),wRation和hRation表示判断的阈值条件newId--新目标的id
# yyy = channel2_post_process([preds], pars) # 送入后处理函数
#
# print('after :\n ', yyy)