AIlib2/utilsK/channel2postUtils.py

307 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
from pathlib import Path
import math
import cv2
import numpy as np
import torch
FILE = Path(__file__).absolute()
#sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
'''
修改说明:
1、pars中增加了recScale参数。船舶判断是否悬挂国旗时需将船舶检测框乘以扩大系数imgScale后与国旗中心点坐标比较。
pars={'imgSize':(imgwidth,imgheight),'wRation':1/6.0,'hRation':1/6.0,'smallId':0,'bigId':3,'newId':4,'recScale':1.2}
2、增加expand_rect(preds_boat, recScale, imgSize)函数在图像范围内将矩形框扩大recScale倍数。
3、增加或修改以下两行
preds_boat_flag_expand=expand_rect(preds_boat_flag[i],pars['recScale'],pars['imgSize']) #新增!
if point_in_rectangle(preds_flag,preds_boat_flag_expand)>=1: #新增后修改!
'''
def channel2_post_process(predsList,pars):
print('----line24:',predsList)
#pars={'imgSize':(imgwidth,imgheight),'wRation':1/6.0,'hRation':1/6.0,'smallId':0,'bigId':3,'newId':4,'recScale':1.2}
'''
后处理程序将检测结果中未悬挂国旗的船只其类别改为4'unflagged_ship'
最终类别汇总如下,
['flag', 'buoy', 'shipname', 'ship','unflagged_ship','uncover']=[0,1,2,3,4,5]
输入:
preds 一张图像的检测结果为嵌套列表tensor包括x_y_x_y_conf_class
imgwidth,imgheight 图像的原始宽度及长度
输出:检测结果(将其中未悬挂国旗的显示)
'''
preds = torch.tensor(predsList[0])
preds=preds.tolist()
preds = filter_detection_results(preds,pars)
preds=[[*sublist[:-1], int(sublist[-1])] for sublist in preds] #类别从浮点型转为整型
#设置空的列表
output_detection=[] #存放往接口传的类别
#1、判断类别中哪些有船取出船检测结果并取出国旗检测结果。
# output_detection.append[] 这里将船和国旗以外的类别加进去
preds_boat=[]
preds_flag=[]
#jcq 增加封仓
preds_uncover = []
# 1、处理未封仓
preds = filter_detection_results(preds,pars)
for i in range(len(preds)):
if preds[i][5]==pars['boatId']: #识别为船
preds_boat.append(preds[i])
elif preds[i][5]==pars['flagId']: #识别为国旗
preds_flag.append(preds[i])
# output_detection.append(preds[i])
#jcq:
elif preds[i][5]==pars['uncoverId']: #未封仓
preds_uncover.append(preds[i])
# output_detection.append(preds[i])
else:
output_detection.append(preds[i])
# pass
# return output_detection
#2、船尺寸与图像比较其中长或宽有一个维度超过图像宽高平均值的1/3启动国旗检测
#①if 判断判断超过1/3的则取出这些大船进一步判断是否悬挂国旗
#不超过1/3的则output_detection.append[]
boat_uncover = preds_boat+preds_uncover
for i in range(len(boat_uncover)):
length_bbx,width_bbx=get_rectangle_dimensions(boat_uncover[i])
length_bbx, width_bbx=int(length_bbx),int(width_bbx)
if length_bbx>(pars['imgSize'][0]+pars['imgSize'][1])* pars['hRation'] or width_bbx>(pars['imgSize'][0]+pars['imgSize'][1])*pars['wRation']:
boat_uncover[i] = unflag(boat_uncover[i], preds_flag, pars)
return output_detection + boat_uncover
def unflag(boat_uncover,preds_flag,pars):
preds_boat_flag_expand = expand_rect(boat_uncover, pars['recScale'], pars['imgSize']) # 新增!
if not point_in_rectangle(preds_flag, preds_boat_flag_expand) >= 1: # 新增后修改!
if boat_uncover[5] == pars['uncoverId']:
boat_uncover[5] = pars['unflagAndcoverId'] # 将类别标签改为6未挂国旗且未封仓
else:
boat_uncover[5] = pars['unflagId'] # 将类别标签改为4即为未悬挂国旗的船只
return boat_uncover
def center_coordinate(boundbxs):
'''
根据检测矩形框,得到其矩形长度和宽度
输入两个对角坐标xyxy
输出矩形框重点坐标xy
'''
boundbxs_x1 = boundbxs[0]
boundbxs_y1 = boundbxs[1]
boundbxs_x2 = boundbxs[2]
boundbxs_y2 = boundbxs[3]
center_x = 0.5 * (boundbxs_x1 + boundbxs_x2)
center_y = 0.5 * (boundbxs_y1 + boundbxs_y2)
return center_x, center_y
def get_rectangle_dimensions(boundbxs):
'''
根据检测矩形框,得到其矩形长度和宽度
输入两个对角坐标xyxy
输出矩形框四个角点坐标以contours顺序。
'''
# 计算两点之间的水平距离
width = math.fabs(boundbxs[2] - boundbxs[0])
# 计算两点之间的垂直距离
height = math.fabs(boundbxs[3]- boundbxs[1])
return width, height
def fourcorner_coordinate(boundbxs):
'''
通过矩形框对角xyxy坐标得到矩形框轮廓
输入两个对角坐标xyxy
输出矩形框四个角点坐标以contours顺序。
'''
boundbxs_x1 = boundbxs[0]
boundbxs_y1 = boundbxs[1]
boundbxs_x2 = boundbxs[2]
boundbxs_y2 = boundbxs[3]
wid = boundbxs_x2 - boundbxs_x1
hei = boundbxs_y2 - boundbxs_y1
boundbxs_x3 = boundbxs_x1 + wid
boundbxs_y3 = boundbxs_y1
boundbxs_x4 = boundbxs_x1
boundbxs_y4 = boundbxs_y1 + hei
contours_rec = [[boundbxs_x1, boundbxs_y1], [boundbxs_x3, boundbxs_y3], [boundbxs_x2, boundbxs_y2],
[boundbxs_x4, boundbxs_y4]]
return contours_rec
def point_in_rectangle(preds_flag,preds_boat_flag):
'''
遍历所有国旗坐标,判断落在检测框中的数量
输入:
preds_flag 国旗类别的检测结果列表
preds_boat_flag 待判定船只的检测结果(单个船只)
输出:落入检测框的国旗数量
'''
iii=0
boat_contour=fourcorner_coordinate(preds_boat_flag)
boat_contour=np.array(boat_contour,dtype=np.float32)
for i in range(len(preds_flag)):
center_x, center_y = center_coordinate(preds_flag[i])
if cv2.pointPolygonTest(boat_contour, (center_x, center_y), False)==1:
iii+=1
else:
pass
return iii
def expand_rect(preds_boat, recScale, imgSize):
'''
在图像范围内将矩形框扩大recScale倍数。
输入:
preds_boat 国旗类别的检测结果列表 xyxy_conf_class
imgSize 从pars传来的元组
输出调整后的preds_boat
'''
# preds_boat_1=preds_boat
preds_boat_1=[x for x in preds_boat]
x1, y1 = preds_boat[0],preds_boat[1]
x2, y2 = preds_boat[2],preds_boat[3]
width = x2 - x1
height = y2 - y1
# 计算新的宽度和高度
new_width = width * recScale
new_height = height * recScale
# 计算新的对角坐标
new_x1 = max(x1 - (new_width - width) / 2, 0) # 确保不会超出左边界
new_y1 = max(y1 - (new_height - height) / 2, 0) # 确保不会超出上边界
new_x2 = min(x2 + (new_width - width) / 2, imgSize[0]) # 图像宽度是imgSize[0]
new_y2 = min(y2 + (new_height - height) / 2, imgSize[1]) # 图像高度是imgSize[1]
preds_boat_1[0]=new_x1
preds_boat_1[1]=new_y1
preds_boat_1[2]=new_x2
preds_boat_1[3]=new_y2
return preds_boat_1
###jcq : 增加封仓后处理函数
def filter_detection_results(results, par):
target_cls = par['target_cls'] # 船只
filter_cls = par['filter_cls'] # 非封仓
# 分离处理与非处理的结果
non_process = [box for box in results if box[5] not in {target_cls, filter_cls}]
to_process = [box for box in results if box[5] in {target_cls, filter_cls}]
# 提取目标类别和过滤类别的检测框
class_target = [box for box in to_process if box[5] == target_cls] # 船只
class_filter = [box for box in to_process if box[5] == filter_cls] # 非封仓
# 处理过滤类别(映射条件)
for i in range(len(class_target)):
t_box = class_target[i]
if any( # 检查是否在任意目标框内部
(t_box[0] <= f_box[0] and
t_box[1] <= f_box[1] and
t_box[2] >= f_box[2] and
t_box[3] >= f_box[3])
for f_box in class_filter
):
class_target[i][5] = par['uncoverId'] # 映射类别4->5
# 合并结果(保留非处理类别)
return non_process + class_target
def filter_detection_results_uncover(results, par):
target_cls = par['target_cls']
filter_cls = par['filter_cls']
# 分离处理与非处理的结果
non_process = [box for box in results if box[5] not in {target_cls, filter_cls}]
to_process = [box for box in results if box[5] in {target_cls, filter_cls}]
# 提取目标类别和过滤类别的检测框
class_target = [box for box in to_process if box[5] == target_cls]
class_filter = [box for box in to_process if box[5] == filter_cls]
processed = []
# 处理过滤类别(映射条件)
for f_box in class_filter:
if any( # 检查是否在任意目标框内部
(f_box[0] >= t_box[0] and
f_box[1] >= t_box[1] and
f_box[2] <= t_box[2] and
f_box[3] <= t_box[3])
for t_box in class_target
):
new_box = f_box.copy()
new_box[5] = 5 # 映射类别4->5
processed.append(new_box)
# 保留所有目标类别检测框
processed += class_target
# 合并结果(保留非处理类别)
return non_process + processed
if __name__ == "__main__":
# 对应DJI_20230306140129_0001_Z_165.jpg检测结果
# preds=[[6.49000e+02, 2.91000e+02, 1.07900e+03, 7.33000e+02, 9.08165e-01, 3.00000e+00],
# [8.11000e+02, 2.99000e+02, 1.31200e+03, 7.65000e+02, 8.61268e-01, 3.00000e+00],
# [7.05000e+02, 1.96000e+02, 7.19000e+02, 2.62000e+02, 5.66877e-01, 0.00000e+00]]
# 对应DJI_20230306152702_0001_Z_562.jpg检测结果
preds=[[7.62000e+02, 7.14000e+02, 1.82800e+03, 9.51000e+02, 9.00902e-01, 3.00000e+00],
[2.00000e+01, 3.45000e+02, 1.51300e+03, 6.71000e+02, 8.81440e-01, 3.00000e+00],
[8.35000e+02, 8.16000e+02, 8.53000e+02, 8.30000e+02, 7.07651e-01, 0.00000e+00],
[1.35600e+03, 4.56000e+02, 1.42800e+03, 4.94000e+02, 6.70549e-01, 2.00000e+00]]
print('before :\n ',preds)
#preds=torch.tensor(preds) #返回的预测结果
imgwidth=1920
imgheight=1680
pars={'imgSize':(imgwidth,imgheight),'wRation':1/6.0,'hRation':1/6.0,'smallId':0,'bigId':3,'newId':4,'recScale':1.2}
# 'smallId':0(国旗)'bigId':3(船只),wRation和hRation表示判断的阈值条件newId--新目标的id
yyy=channel2_post_process([preds],pars) #送入后处理函数
print('after :\n ',yyy)