urban_management/DMPR_YOLO/jointUtil.py

95 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import torch
def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int):
device_ = yolo_det.device
# dmpr_det内容为conf, x, y, θ, shape
if dmpr_det.device != device_:
dmpr_det = dmpr_det.to(device_)
# 创建yolo_det_clone内容为x1, y1, x2, y2, conf, cls, unlabel (unlabel代表该类是否需要忽略0不忽略 其他:忽略)
yolo_det_clone = yolo_det.clone().detach()
tmp_0_tensor = torch.zeros([len(yolo_det), 1], device=device_)
yolo_det_clone = torch.cat([yolo_det_clone, tmp_0_tensor], dim=1)
# cls为需要计算的类别
yolo_det = yolo_det[yolo_det[:, -1] == cls]
# new_yolo_det为膨胀后数据内容为x1, y1, x2, y2, flag (flag代表膨胀后车位内是否包含角点 且 与角点方向差值小于90度, 其值为第一个满足条件的角点索引)
new_yolo_det = torch.zeros([len(yolo_det), 5], device=device_)
# yolo框膨胀长的边两边各膨胀0.4倍总长短的边两边各膨胀0.2倍总长
x_length = yolo_det[:, 2] - yolo_det[:, 0] #x2-x1
y_length = yolo_det[:, 3] - yolo_det[:, 1] #y2-y1
# x, y哪个方向差值大哪个方向膨胀的多
x_dilate_coefficient = ((x_length > y_length).int() + 1)*0.2
y_dilate_coefficient = ((~(x_length > y_length)).int() + 1)*0.2
# 膨胀
new_yolo_det[:, 0] = torch.round(yolo_det[:, 0] - x_dilate_coefficient * x_length).clamp_(0, img_shape[1]) #x1 膨胀
new_yolo_det[:, 1] = torch.round(yolo_det[:, 1] - y_dilate_coefficient * y_length).clamp_(0, img_shape[0]) #y1 膨胀
new_yolo_det[:, 2] = torch.round(yolo_det[:, 2] + x_dilate_coefficient * x_length).clamp_(0, img_shape[1]) #x2 膨胀
new_yolo_det[:, 3] = torch.round(yolo_det[:, 3] + y_dilate_coefficient * y_length).clamp_(0, img_shape[0]) #y2 膨胀
# 判断膨胀后yolo框包含角点关系 && 包含角点的时候计算水平框中心点与角点的角度关系
# for i in range(0, len(new_yolo_det)):
# for j in range(0, len(dmpr_det)):
# if new_yolo_det[i, 4] == 0:
# [x_p, y_p] = dmpr_det[j, 1:3]
# [x1, y1, x2, y2] = new_yolo_det[i, :4]
# x_c = (x1+x2)/2
# y_c = (y1+y2)/2
# if (x_p >= x1) and (x_p <= x2) and (y_p >= y1) and (y_p <= y2):
# direction1 = math.atan2(y_c-y_p, x_c-x_p) / math.pi * 180
# direction2 = dmpr_det[j, 3] / math.pi * 180
# ang_diff = direction1 - direction2
# # direction ∈ -180 180 若角差大于180需算补角
# if (ang_diff >= -90) and (ang_diff <= 90):
# new_yolo_det[i, 4] = j + 1 #为防止 j = 0 时赋值,故作 +1 操作
# elif (ang_diff > 180) and (360 - ang_diff <= 90):
# new_yolo_det[i, 4] = j + 1
# elif (ang_diff < -180) and (360 + ang_diff <= 90):
# new_yolo_det[i, 4] = j + 1
m, n = len(new_yolo_det), len(dmpr_det)
if not m or not n:
return yolo_det_clone, new_yolo_det
new_yolo = new_yolo_det.unsqueeze(dim=1).repeat(1, n, 1) # 扩展为 (m , n, 5)
dmpr_det = dmpr_det.unsqueeze(dim=0).repeat(m, 1, 1)
yolo_dmpr = torch.cat((new_yolo, dmpr_det), dim=2) # (m, n, 10)
x_p, y_p = yolo_dmpr[..., 6], yolo_dmpr[..., 7]
x1, y1, x2, y2 = yolo_dmpr[..., 0], yolo_dmpr[..., 1], yolo_dmpr[..., 2], yolo_dmpr[..., 3]
x_c, y_c = (x1+x2)/2, (y1+y2)/2
direction1 = torch.atan2(y_c - y_p, x_c - x_p) / math.pi * 180
direction2 = yolo_dmpr[..., 8] / math.pi * 180
ang_diff = direction1 - direction2
# 判断膨胀后yolo框包含角点关系 & & 包含角点的时候计算水平框中心点与角点的角度关系
# direction ∈ -180 180 若角差大于180需算补角
mask = (x_p >= x1) & (x_p <= x2) & (y_p >= y1) & (y_p <= y2) & \
(((ang_diff >= -90) & (ang_diff <= 90)) | ((ang_diff > 180) & ((360 - ang_diff) <= 90)) | (((ang_diff) < -180) & ((360 + ang_diff) <= 90)))
res = torch.sum(mask, dim=1).float()
# 索引两次更新tensor test1
# yolo_det_clone[yolo_det_clone[:, -2] == cls][:, -1] = new_yolo_det[:, 4]
# 索引两次更新tensor test2
# a = [x for x in torch.arange(len(new_yolo_det))]
# b = [6 for _ in torch.arange(len(new_yolo_det))]
# index = (torch.LongTensor(a), torch.LongTensor(b))
# value = torch.tensor(new_yolo_det[:, 4], device=device_)
# yolo_det_clone[yolo_det_clone[:, -2] == cls].index_put_(index, value)
yolo_det_clone[yolo_det_clone[:, -2] == cls, -1] = res
return yolo_det_clone, new_yolo_det