# AIlib2/DMPRUtils/DMPR_process.py

import math
import os
import time
from collections import namedtuple
import cv2
import numpy as np
import torch
from torchvision.transforms import ToTensor
from DMPRUtils.model import DirectionalPointDetector
from utils.datasets import letterbox
from utils.general import clip_coords
from utils.torch_utils import select_device
#from DMPRUtils.trtUtils import TrtForwardCase
#import segutils.trtUtils.segTrtForward as TrtForwardCase
from segutils.trtUtils import segTrtForward
# A detected marking point: normalized x/y position in [0, 1], direction angle
# in radians, and a shape score (plot_points treats shape > 0.5 specially).
MarkingPoint = namedtuple('MarkingPoint', ['x', 'y', 'direction', 'shape'])
def plot_points(image, pred_points, line_thickness=3):
    """Draw detected marking points onto *image* in place.

    Each entry of ``pred_points`` is ``(confidence, x, y, direction, shape)``.
    For every point a direction arm is drawn plus a side bar whose placement
    and thickness depend on the shape score.
    """
    if not len(pred_points):
        return
    # line thickness: explicit value, or an image-size based default
    tl = line_thickness or round(0.002 * (image.shape[0] + image.shape[1]) / 2) + 1
    tf = max(tl - 1, 1)  # font thickness
    red = (0, 0, 255)
    for conf, *point in pred_points:
        px, py = int(point[0]), int(point[1])
        cos_val = math.cos(point[2])
        sin_val = math.sin(point[2])
        tip = (int(px + 20 * cos_val * tl), int(py + 20 * sin_val * tl))
        side_a = (int(px - 10 * sin_val * tl), int(py + 10 * cos_val * tl))
        side_b = (int(px + 10 * sin_val * tl), int(py - 10 * cos_val * tl))
        # direction arm and confidence label
        cv2.line(image, (px, py), tip, red, thickness=tl)
        cv2.putText(image, str(float(conf)), (px, py), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), thickness=tf)
        if point[3] > 0.5:
            cv2.line(image, (px, py), side_a, red, thickness=tl)
        else:
            cv2.line(image, side_a, side_b, red, thickness=tf)
def preprocess_image(image, size=640):
    """Resize *image* to a square network input and convert it to a tensor.

    Generalized: the target side length is now a parameter (default 640 keeps
    the original behavior for existing callers).

    Args:
        image: HxWxC numpy image (as read by OpenCV).
        size: target square side length in pixels.

    Returns:
        Float tensor of shape (1, C, size, size), values scaled to [0, 1]
        by torchvision's ToTensor.
    """
    if image.shape[0] != size or image.shape[1] != size:
        image = cv2.resize(image, (size, size))
    return torch.unsqueeze(ToTensor()(image), 0)
def non_maximum_suppression(pred_points):
    """Suppress near-duplicate marking points, keeping the higher-confidence one.

    ``pred_points`` is a list of ``(confidence, point)`` pairs where *point*
    exposes normalized ``x``/``y`` attributes (e.g. MarkingPoint).  Two points
    closer than 1/16 on both axes are duplicates; the lower-confidence one is
    dropped.

    Fix: the original timing ``print`` sat after an early ``return`` and only
    fired when *nothing* was suppressed — misleading leftover instrumentation,
    removed.

    Returns:
        A list of the surviving ``(confidence, point)`` pairs, in input order.
    """
    suppressed = [False] * len(pred_points)
    for i in range(len(pred_points) - 1):
        for j in range(i + 1, len(pred_points)):
            i_x, i_y = pred_points[i][1].x, pred_points[i][1].y
            j_x, j_y = pred_points[j][1].x, pred_points[j][1].y
            # 0.0625 = 1 / 16 of the normalized image span
            if abs(j_x - i_x) < 0.0625 and abs(j_y - i_y) < 0.0625:
                # drop whichever of the pair has the lower confidence
                idx = i if pred_points[i][0] < pred_points[j][0] else j
                suppressed[idx] = True
    return [p for p, s in zip(pred_points, suppressed) if not s]
def ms(t2, t1):
    """Return the elapsed time from *t1* to *t2* formatted as milliseconds.

    The result is a string with one decimal place and a trailing space,
    e.g. ``'12.3 '``, ready for concatenation into timing logs.
    """
    elapsed_ms = (t2 - t1) * 1000
    return '%.1f ' % elapsed_ms
def get_predicted_points(prediction, thresh):
    """Extract marking points from one predicted feature map.

    Args:
        prediction: torch tensor of shape (6, H, W); per-cell channels are
            (confidence, shape, x-offset, y-offset, cos, sin).
        thresh: confidence threshold; cells at or below it are discarded.

    Returns:
        (points, time_info): *points* is an (N, 5) tensor of
        (confidence, x, y, direction, shape) with x/y normalized to [0, 1];
        *time_info* is a profiling string of per-stage millisecond timings.
    """
    t0 = time.time()
    assert isinstance(prediction, torch.Tensor)
    # (6, H, W) -> (H, W, 6)
    pred = prediction.permute(1, 2, 0).contiguous()
    height, width = pred.shape[0], pred.shape[1]
    # append the cell column (j) and row (i) indices as two extra channels
    col = torch.arange(width, device=pred.device).float().repeat(height, 1).unsqueeze(dim=2)
    row = torch.arange(height, device=pred.device).float().view(height, 1).repeat(1, width).unsqueeze(dim=2)
    pred = torch.cat((pred, col, row), dim=2).view(-1, 8).contiguous()
    t1 = time.time()
    # keep only cells above the confidence threshold
    keep = pred[..., 0] > thresh
    t2 = time.time()
    pred = pred[keep]
    t3 = time.time()
    # cell-relative offset + cell index, normalized by the feature-map size
    pred[..., 2] = (pred[..., 2] + pred[..., 6]) / width
    pred[..., 3] = (pred[..., 3] + pred[..., 7]) / height
    direction = torch.atan2(pred[..., 5], pred[..., 4])
    pred = torch.stack((pred[..., 0], pred[..., 2], pred[..., 3], direction, pred[..., 1]), dim=1)
    t4 = time.time()
    timeInfo = 'rerange:%s scoreFilter:%s , getMask:%s stack:%s ' % (ms(t1, t0), ms(t2, t1), ms(t3, t2), ms(t4, t3))
    return pred, timeInfo
def get_predicted_points_np(prediction, thresh):
    """Numpy variant of get_predicted_points: extract marking points from a map.

    Args:
        prediction: torch tensor of shape (6, H, W); per-cell channels are
            (confidence, shape, x-offset, y-offset, cos, sin).
        thresh: confidence threshold; cells at or below it are discarded.

    Returns:
        (N, 5) numpy array of (confidence, x, y, direction, shape) with
        x/y normalized to [0, 1].

    Fixes vs. the original:
      * ``np.arctan(a, b)`` treated *b* as the **out** array, silently
        clobbering the cos channel and returning arctan(sin) instead of the
        two-argument angle — replaced with ``np.arctan2`` to match
        ``torch.atan2`` in get_predicted_points.
      * the index channels were appended as (row, col), so the *row* index was
        added to x and divided by width; reordered to (col, row) to match the
        torch implementation.
      * debug timing prints removed.
    """
    pred = prediction.permute(1, 2, 0).contiguous()  # (6, H, W) -> (H, W, 6)
    pred = pred.cpu().detach().numpy()
    height, width = pred.shape[0:2]
    row, col = np.mgrid[0:height, 0:width]
    col = np.expand_dims(col, axis=2)
    row = np.expand_dims(row, axis=2)
    # append column (pairs with x / width) then row (pairs with y / height)
    pred = np.concatenate((pred, col, row), axis=2).reshape(-1, 8)
    # keep only cells above the confidence threshold
    pred = pred[pred[..., 0] > thresh]
    pred[..., 2] = (pred[..., 2] + pred[..., 6]) / width
    pred[..., 3] = (pred[..., 3] + pred[..., 7]) / height
    direction = np.arctan2(pred[..., 5:6], pred[..., 4:5])
    return np.hstack((pred[:, 0:1], pred[:, 2:3], pred[:, 3:4], direction, pred[:, 1:2]))
def detect_marking_points(detector, image, thresh, device, modelType='pth'):
    """Given an image read from OpenCV, return detected marking points.

    Args:
        detector: torch model ('pth') or TensorRT engine ('trt').
        image: HxWxC numpy image.
        thresh: confidence threshold forwarded to get_predicted_points.
        device: torch device the detector runs on.
        modelType: 'pth' for the torch model, 'trt' for the TensorRT path.

    Returns:
        (N, 5) tensor of (confidence, x, y, direction, shape).

    Fixes: removed the dead ``a = 0`` and the computed-but-unused timing
    string; an unknown ``modelType`` now raises instead of leaving
    ``prediction`` unbound; ``torch.cuda.synchronize`` only runs when the
    input is actually on a CUDA device (it raises on CPU-only setups).
    """
    image_preprocess = preprocess_image(image).to(device)
    if modelType == 'pth':
        prediction = detector(image_preprocess)
    elif modelType == 'trt':
        prediction = segTrtForward(detector, [image_preprocess])
    else:
        raise ValueError('unsupported modelType: %r' % (modelType,))
    # synchronize so downstream timing/postprocessing sees finished results
    if image_preprocess.is_cuda:
        torch.cuda.synchronize(device)
    rets, timeInfo = get_predicted_points(prediction[0], thresh)
    return rets
def scale_coords2(img1_shape, coords, img0_shape, ratio_pad=None):
    """Rescale normalized (x, y) coords from the letterboxed img1 back to img0.

    Args:
        img1_shape: (height, width) of the network input image.
        coords: (N, >=2) torch tensor or numpy array whose first two columns
            are x, y as fractions of img1.
        img0_shape: (height, width) of the original image.
        ratio_pad: optional ((gain, ...), (pad_x, pad_y)) override; when None
            the gain and padding are derived from the two shapes.

    Returns:
        *coords*, modified in place, in img0 pixel coordinates.
    """
    if ratio_pad is None:
        # gain = old / new; pad = letterbox border added on each side
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    height, width = img1_shape
    is_tensor = isinstance(coords, torch.Tensor)
    # fractional -> absolute pixel coordinates in img1
    if is_tensor:
        coords[:, 0] = torch.round(width * coords[:, 0] - 0.5)
        coords[:, 1] = torch.round(height * coords[:, 1] - 0.5)
    else:
        # NOTE(review): the numpy path rounds via +0.5/int-cast while the
        # torch path uses round(v - 0.5); results can differ by one pixel
        # for some values — confirm which rounding is intended.
        coords[:, 0] = (width * coords[:, 0] + 0.5).astype(np.int32)
        coords[:, 1] = (height * coords[:, 1] + 0.5).astype(np.int32)
    # undo the letterbox padding, then scale back to img0 size
    coords[:, 0] -= pad[0]
    coords[:, 1] -= pad[1]
    coords[:, :3] /= gain
    # clamp into the original image bounds
    if is_tensor:
        coords[:, 0].clamp_(0, img0_shape[1])
        coords[:, 1].clamp_(0, img0_shape[0])
    else:
        coords[:, 0] = np.clip(coords[:, 0], 0, img0_shape[1])
        coords[:, 1] = np.clip(coords[:, 1], 0, img0_shape[0])
    return coords
def DMPR_process(img0, model, device, DMPRmodelPar):
    """Run the DMPR marking-point detector on an original-resolution image.

    Args:
        img0: HxWxC numpy image at its original resolution.
        model: detector forwarded to detect_marking_points.
        device: torch device.
        DMPRmodelPar: mapping with keys 'dmprimg_size', 'dmpr_thresh'
            and 'modelType'.

    Returns:
        (det, timeInfos): *det* is an (N, 5) result of
        (confidence, x, y, direction, shape) with x/y in img0 pixel
        coordinates; *timeInfos* is a profiling string.

    Fix: the total-time format was '%1.f' (width 1, zero decimals) —
    corrected to '%.1f' to match the per-stage fields.
    """
    t0 = time.time()
    height, width, _ = img0.shape
    img, ratio, (dw, dh) = letterbox(img0, DMPRmodelPar['dmprimg_size'], auto=False)
    t1 = time.time()
    det = detect_marking_points(model, img, DMPRmodelPar['dmpr_thresh'], device, modelType=DMPRmodelPar['modelType'])
    t2 = time.time()
    if len(det):
        # map normalized x/y back to original-image pixels
        det[:, 1:3] = scale_coords2(img.shape[:2], det[:, 1:3], img0.shape)
    t3 = time.time()
    timeInfos = 'dmpr:%.1f (lettbox:%.1f dectect:%.1f scaleBack:%.1f) ' % ((t3 - t0) * 1000, (t1 - t0) * 1000, (t2 - t1) * 1000, (t3 - t2) * 1000,)
    return det, timeInfos
if __name__ == '__main__':
    # Smoke-test entry point: detect marking points on one image and save it.
    impath = r'I:\zjc\weiting1\Images'
    file = 'DJI_0001_8.jpg'
    imgpath = os.path.join(impath, file)
    img0 = cv2.imread(imgpath)
    device_ = '0'
    device = select_device(device_)
    # NOTE(review): `config` is never imported in this module, so this line
    # raises NameError as-is; import the DMPR-PS config module before running.
    args = config.get_parser_for_inference().parse_args()
    model = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    weights = r"E:\pycharmProject\DMPR-PS\weights\dp_detector_499.pth"
    model.load_state_dict(torch.load(weights))
    # NOTE(review): DMPR_process indexes DMPRmodelPar like a dict
    # (['dmprimg_size'], ...); an argparse Namespace will fail — confirm the
    # expected parameter shape.
    # BUG FIX: DMPR_process returns (det, timeInfos); the original passed the
    # whole tuple straight to plot_points.
    det, timeInfos = DMPR_process(img0, model, device, args)
    plot_points(img0, det)
    cv2.imwrite(file, img0, [int(cv2.IMWRITE_JPEG_QUALITY), 100])