# urban_management/DMPRUtils/DMPR_process.py
import math
import os
import time
from collections import namedtuple
import cv2
import numpy as np
import torch
from torchvision.transforms import ToTensor
from DMPRUtils.model import DirectionalPointDetector
from conf import config
from utils.datasets import letterbox
from utils.general import clip_coords
from utils.torch_utils import select_device
# A detected marking point: x/y are normalized to [0, 1], direction is an
# angle in radians, shape is a score (plot_points branches on it at 0.5).
MarkingPoint = namedtuple('MarkingPoint', ['x', 'y', 'direction', 'shape'])
def plot_points(image, pred_points, line_thickness=3):
    """Draw marking points on `image` in place.

    Args:
        image: BGR numpy image to draw on.
        pred_points: rows of (conf, x, y, direction, shape) — a numpy array
            or torch tensor of shape (N, 5) with x/y in pixel coordinates.
        line_thickness: base line thickness (0/None → derive from image size).
    """
    # Use len() so both numpy arrays and torch tensors are handled; the
    # original `pred_points.size` is a bound *method* on a tensor and was
    # therefore always truthy for tensor input.
    if len(pred_points):
        tl = line_thickness or round(0.002 * (image.shape[0] + image.shape[1]) / 2) + 1  # line/font thickness
        tf = max(tl - 1, 1)  # font thickness
        for conf, *point in pred_points:
            p0_x, p0_y = int(point[0]), int(point[1])
            cos_val = math.cos(point[2])
            sin_val = math.sin(point[2])
            # Tip of the direction arrow.
            p1_x = int(p0_x + 20 * cos_val * tl)
            p1_y = int(p0_y + 20 * sin_val * tl)
            # Endpoints perpendicular to the direction, used for the shape mark.
            p2_x = int(p0_x - 10 * sin_val * tl)
            p2_y = int(p0_y + 10 * cos_val * tl)
            p3_x = int(p0_x + 10 * sin_val * tl)
            p3_y = int(p0_y - 10 * cos_val * tl)
            cv2.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255), thickness=tl)
            cv2.putText(image, str(float(conf)), (p0_x, p0_y), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), thickness=tf)
            if point[3] > 0.5:
                cv2.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255), thickness=tl)
            else:
                cv2.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255), thickness=tf)
def preprocess_image(image, size=640):
    """Preprocess a numpy image into a batched torch tensor.

    Resizes to (size, size) if needed, then converts HWC to a CHW float
    tensor via ToTensor and adds a leading batch dimension.

    Args:
        image: HWC numpy image.
        size: target square side length (default 640 — the network input
            size previously hard-coded here).

    Returns:
        Tensor of shape (1, C, size, size).
    """
    if image.shape[0] != size or image.shape[1] != size:
        image = cv2.resize(image, (size, size))
    return torch.unsqueeze(ToTensor()(image), 0)
def non_maximum_suppression(pred_points):
    """Suppress near-duplicate marking points, keeping the higher-confidence one.

    Args:
        pred_points: list of (confidence, MarkingPoint) tuples with
            normalized x/y coordinates.

    Returns:
        The filtered list (the input list itself when nothing is suppressed).
    """
    keep = [True] * len(pred_points)
    for i, (conf_i, pt_i) in enumerate(pred_points[:-1]):
        for j in range(i + 1, len(pred_points)):
            conf_j, pt_j = pred_points[j]
            # Points closer than 1/16 of the image along both axes are
            # treated as duplicates; drop the lower-confidence one.
            if abs(pt_j.x - pt_i.x) < 0.0625 and abs(pt_j.y - pt_i.y) < 0.0625:
                keep[i if conf_i < conf_j else j] = False
    if all(keep):
        return pred_points
    return [point for point, kept in zip(pred_points, keep) if kept]
def get_predicted_points(prediction, thresh):
    """Extract marking points from one predicted feature map (loop version).

    Args:
        prediction: torch tensor of shape (6, H, W); channels appear to be
            confidence, shape, dx, dy, cos, sin (matches the vectorized
            get_predicted_points2).
        thresh: confidence threshold.

    Returns:
        NMS-filtered list of (confidence, MarkingPoint) tuples.
    """
    assert isinstance(prediction, torch.Tensor)
    pred = prediction.detach().cpu().numpy()
    height, width = pred.shape[1], pred.shape[2]
    points = []
    for row in range(height):
        for col in range(width):
            conf = pred[0, row, col]
            if conf < thresh:
                continue
            # Cell index plus predicted sub-cell offset, normalized to [0, 1].
            xval = (col + pred[2, row, col]) / width
            yval = (row + pred[3, row, col]) / height
            direction = math.atan2(pred[5, row, col], pred[4, row, col])
            marking_point = MarkingPoint(xval, yval, direction, pred[1, row, col])
            points.append((conf, marking_point))
    return non_maximum_suppression(points)
def get_predicted_points2(prediction, thresh):
    """Extract marking points from one predicted feature map (vectorized).

    Args:
        prediction: torch tensor of shape (6, H, W); channels are
            confidence, shape, dx, dy, cos, sin.
        thresh: confidence threshold (strictly greater-than keeps a cell).

    Returns:
        Tensor of shape (N, 5) with columns (conf, x, y, direction, shape),
        x/y normalized to [0, 1]. No NMS is applied here.
    """
    assert isinstance(prediction, torch.Tensor)
    feat = prediction.permute(1, 2, 0).contiguous()  # (H, W, 6)
    height, width = feat.shape[0], feat.shape[1]
    # Append each cell's own column (j) and row (i) index as two extra channels.
    col_idx = torch.arange(width, device=feat.device).float().repeat(height, 1).unsqueeze(dim=2)
    row_idx = torch.arange(height, device=feat.device).float().view(height, 1).repeat(1, width).unsqueeze(dim=2)
    feat = torch.cat((feat, col_idx, row_idx), dim=2)
    # Drop cells at or below the confidence threshold.
    feat = feat[feat[..., 0] > thresh]
    # Cell index plus sub-cell offset, normalized to [0, 1].
    feat[..., 2] = (feat[..., 2] + feat[..., 6]) / width
    feat[..., 3] = (feat[..., 3] + feat[..., 7]) / height
    direction = torch.atan2(feat[..., 5], feat[..., 4])
    return torch.stack((feat[..., 0], feat[..., 2], feat[..., 3], direction, feat[..., 1]), dim=1)
def detect_marking_points(detector, image, thresh, device):
    """Given image read from opencv, return detected marking points.

    Args:
        detector: the marking-point detection model.
        image: BGR numpy image.
        thresh: confidence threshold.
        device: torch device to run inference on.

    Returns:
        Tensor of shape (N, 5): (conf, x, y, direction, shape), x/y normalized.
    """
    batch = preprocess_image(image).to(device)
    prediction = detector(batch)
    # The model output's first element is the feature map for this image.
    return get_predicted_points2(prediction[0], thresh)
def scale_coords2(img1_shape, coords, img0_shape, ratio_pad=None):
    """Rescale normalized (x, y) coords from img1_shape back to img0_shape.

    Args:
        img1_shape: (height, width) of the letterboxed network input.
        coords: tensor whose first two columns are x, y in [0, 1];
            modified in place.
        img0_shape: (height, width) of the original image.
        ratio_pad: optional ((gain, ...), (pad_w, pad_h)) override; when
            None, gain and padding are recomputed from the two shapes.

    Returns:
        The same tensor, with x/y in original-image pixel coordinates.
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    # Convert fractional x/y into pixel coordinates of the letterboxed image.
    height, width = img1_shape
    coords[:, 0] = torch.round(width * coords[:, 0] - 0.5)
    coords[:, 1] = torch.round(height * coords[:, 1] - 0.5)
    coords[:, 0] -= pad[0]  # x padding
    coords[:, 1] -= pad[1]  # y padding
    # Scale only the two coordinate columns. The previous `coords[:, :3]`
    # (copied from a 4-column box variant) would also divide a third column
    # (e.g. the direction angle) by the gain if one were ever passed in.
    coords[:, :2] /= gain
    # Clamp back into the original image bounds.
    coords[:, 0].clamp_(0, img0_shape[1])
    coords[:, 1].clamp_(0, img0_shape[0])
    return coords
def DMPR_process(img0, model, device, args):
    """Detect marking points on an original-size image.

    Args:
        img0: original BGR image (3-channel HWC numpy array).
        model: the marking-point detector.
        device: torch device.
        args: namespace providing `dmprimg_size` and `dmpr_thresh`.

    Returns:
        Tensor of shape (N, 5): (conf, x, y, direction, shape), with x/y
        mapped back to original-image pixel coordinates.
    """
    height, width, _ = img0.shape  # unpack also checks for a 3-dim image
    # Letterbox to the network input size, then detect.
    img, ratio, (dw, dh) = letterbox(img0, args.dmprimg_size, auto=False)
    det = detect_marking_points(model, img, args.dmpr_thresh, device)
    if len(det):
        # Map normalized x/y from the letterboxed frame back onto img0.
        det[:, 1:3] = scale_coords2(img.shape[:2], det[:, 1:3], img0.shape)
    # conf, x, y, θ, shape
    return det
if __name__ == '__main__':
    # Smoke test: run the detector on one local image and save the
    # visualization next to the script. Paths and weights are machine-specific.
    impath = r'I:\zjc\weiting1\Images'
    file = 'DJI_0001_8.jpg'
    imgpath = os.path.join(impath, file)
    img0 = cv2.imread(imgpath)
    device_ = '0'  # CUDA device index string for select_device
    device = select_device(device_)
    args = config.get_parser_for_inference().parse_args()
    model = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    weights = r"E:\pycharmProject\DMPR-PS\weights\dp_detector_499.pth"
    model.load_state_dict(torch.load(weights))
    det = DMPR_process(img0, model, device, args)
    # Draw the detections in place, then write at maximum JPEG quality.
    plot_points(img0, det)
    cv2.imwrite(file, img0, [int(cv2.IMWRITE_JPEG_QUALITY), 100])