import math import torch def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int): device_ = yolo_det.device # dmpr_det内容为conf, x, y, θ, shape if dmpr_det.device != device_: dmpr_det = dmpr_det.to(device_) # 创建yolo_det_clone内容为x1, y1, x2, y2, conf, cls, unlabel (unlabel代表该类是否需要忽略,0:不忽略 其他:忽略) yolo_det_clone = yolo_det.clone().detach() tmp_0_tensor = torch.zeros([len(yolo_det), 1], device=device_) yolo_det_clone = torch.cat([yolo_det_clone, tmp_0_tensor], dim=1) # cls为需要计算的类别 yolo_det = yolo_det[yolo_det[:, -1] == cls] # new_yolo_det为膨胀后数据,内容为x1, y1, x2, y2, flag (flag代表膨胀后车位内是否包含角点 且 与角点方向差值小于90度, 其值为第一个满足条件的角点索引) new_yolo_det = torch.zeros([len(yolo_det), 5], device=device_) # yolo框膨胀,长的边两边各膨胀0.4倍总长,短的边两边各膨胀0.2倍总长 x_length = yolo_det[:, 2] - yolo_det[:, 0] #x2-x1 y_length = yolo_det[:, 3] - yolo_det[:, 1] #y2-y1 # x, y哪个方向差值大哪个方向膨胀的多 x_dilate_coefficient = ((x_length > y_length).int() + 1)*0.2 y_dilate_coefficient = ((~(x_length > y_length)).int() + 1)*0.2 # 膨胀 new_yolo_det[:, 0] = torch.round(yolo_det[:, 0] - x_dilate_coefficient * x_length).clamp_(0, img_shape[1]) #x1 膨胀 new_yolo_det[:, 1] = torch.round(yolo_det[:, 1] - y_dilate_coefficient * y_length).clamp_(0, img_shape[0]) #y1 膨胀 new_yolo_det[:, 2] = torch.round(yolo_det[:, 2] + x_dilate_coefficient * x_length).clamp_(0, img_shape[1]) #x2 膨胀 new_yolo_det[:, 3] = torch.round(yolo_det[:, 3] + y_dilate_coefficient * y_length).clamp_(0, img_shape[0]) #y2 膨胀 # 判断膨胀后yolo框包含角点关系 && 包含角点的时候计算水平框中心点与角点的角度关系 # for i in range(0, len(new_yolo_det)): # for j in range(0, len(dmpr_det)): # if new_yolo_det[i, 4] == 0: # [x_p, y_p] = dmpr_det[j, 1:3] # [x1, y1, x2, y2] = new_yolo_det[i, :4] # x_c = (x1+x2)/2 # y_c = (y1+y2)/2 # if (x_p >= x1) and (x_p <= x2) and (y_p >= y1) and (y_p <= y2): # direction1 = math.atan2(y_c-y_p, x_c-x_p) / math.pi * 180 # direction2 = dmpr_det[j, 3] / math.pi * 180 # ang_diff = direction1 - direction2 # # direction ∈ (-180, 180) 若角差大于180,需算补角 # if (ang_diff >= -90) and (ang_diff <= 90): # new_yolo_det[i, 4] = j + 1 #为防止 j = 0 时赋值,故作 +1 操作 # elif (ang_diff > 180) and (360 - ang_diff <= 90): # new_yolo_det[i, 4] = j + 1 # elif (ang_diff < -180) and (360 + ang_diff <= 90): # new_yolo_det[i, 4] = j + 1 m, n = len(new_yolo_det), len(dmpr_det) if not m or not n: return yolo_det_clone, new_yolo_det new_yolo = new_yolo_det.unsqueeze(dim=1).repeat(1, n, 1) # 扩展为 (m , n, 5) dmpr_det = dmpr_det.unsqueeze(dim=0).repeat(m, 1, 1) yolo_dmpr = torch.cat((new_yolo, dmpr_det), dim=2) # (m, n, 10) x_p, y_p = yolo_dmpr[..., 6], yolo_dmpr[..., 7] x1, y1, x2, y2 = yolo_dmpr[..., 0], yolo_dmpr[..., 1], yolo_dmpr[..., 2], yolo_dmpr[..., 3] x_c, y_c = (x1+x2)/2, (y1+y2)/2 direction1 = torch.atan2(y_c - y_p, x_c - x_p) / math.pi * 180 direction2 = yolo_dmpr[..., 8] / math.pi * 180 ang_diff = direction1 - direction2 # 判断膨胀后yolo框包含角点关系 & & 包含角点的时候计算水平框中心点与角点的角度关系 # direction ∈ (-180, 180) 若角差大于180,需算补角 mask = (x_p >= x1) & (x_p <= x2) & (y_p >= y1) & (y_p <= y2) & \ (((ang_diff >= -90) & (ang_diff <= 90)) | ((ang_diff > 180) & ((360 - ang_diff) <= 90)) | (((ang_diff) < -180) & ((360 + ang_diff) <= 90))) res = torch.sum(mask, dim=1).float() # 索引两次更新tensor test1 # yolo_det_clone[yolo_det_clone[:, -2] == cls][:, -1] = new_yolo_det[:, 4] # 索引两次更新tensor test2 # a = [x for x in torch.arange(len(new_yolo_det))] # b = [6 for _ in torch.arange(len(new_yolo_det))] # index = (torch.LongTensor(a), torch.LongTensor(b)) # value = torch.tensor(new_yolo_det[:, 4], device=device_) # yolo_det_clone[yolo_det_clone[:, -2] == cls].index_put_(index, value) yolo_det_clone[yolo_det_clone[:, -2] == cls, -1] = res return yolo_det_clone, new_yolo_det