tensor
This commit is contained in:
parent
65d533f0a1
commit
098d947752
|
|
@ -4,7 +4,7 @@
|
|||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.9 (torch1.7) (15)" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.8 (yolov5)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
|
|
@ -2,6 +2,13 @@
|
|||
<project version="4">
|
||||
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<serverData>
|
||||
<paths name="th@192.168.11.8:32178 password">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="thsw@192.168.10.11:22 password">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (torch1.7) (15)" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (yolov5)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
|
|
@ -6,6 +6,7 @@ import torch
|
|||
|
||||
from DMPRUtils.DMPR_process import DMPR_process, plot_points
|
||||
from DMPRUtils.model.detector import DirectionalPointDetector
|
||||
from DMPRUtils.yolo_net import Model
|
||||
from DMPR_YOLO.jointUtil import dmpr_yolo
|
||||
from conf import config
|
||||
from models.experimental import attempt_load
|
||||
|
|
@ -16,12 +17,12 @@ from utils.torch_utils import select_device
|
|||
|
||||
def main():
|
||||
##预先设置的参数
|
||||
device_ = '0' ##选定模型,可选 cpu,'0','1'
|
||||
device_ = '1' ##选定模型,可选 cpu,'0','1'
|
||||
|
||||
##以下参数目前不可改
|
||||
Detweights = 'weights/urbanManagement/yolo/best.pt'
|
||||
seg_nclass = 2
|
||||
DMPRweights = "weights/urbanManagement/DMPR/dp_detector_499.pth"
|
||||
DMPRweights = "weights/urbanManagement/DMPR/dp_detector_299.pth"
|
||||
conf_thres, iou_thres, classes = 0.25, 0.45, 3
|
||||
labelnames = "weights/yolov5/class5/labelnames.json"
|
||||
rainbows = [[0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 129, 0], [255, 0, 127],
|
||||
|
|
@ -41,26 +42,29 @@ def main():
|
|||
|
||||
# DMPR model
|
||||
args = config.get_parser_for_inference().parse_args()
|
||||
DMPRmodel = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
|
||||
# DMPRmodel = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
|
||||
# DMPRmodel.load_state_dict(torch.load(DMPRweights))
|
||||
DMPRmodel = Model(args.cfg, ch=3).to(device)
|
||||
DMPRmodel.load_state_dict(torch.load(DMPRweights))
|
||||
|
||||
# 图像测试
|
||||
# impth = 'images/input'
|
||||
impth = 'images/debug'
|
||||
# outpth = 'images/output'
|
||||
outpth = 'images/debug_out'
|
||||
impth = 'images/input'
|
||||
# impth = 'images/debug'
|
||||
# impth = '/home/thsw/WJ/zjc/AI_old/images/input_0'
|
||||
outpth = 'images/output'
|
||||
# outpth = 'images/debug_out'
|
||||
folders = os.listdir(impth)
|
||||
for file in folders:
|
||||
imgpath = os.path.join(impth, file)
|
||||
img0 = cv2.imread(imgpath)
|
||||
assert img0 is not None, 'Image Not Found ' + imgpath
|
||||
|
||||
# t_start = time.time()
|
||||
t_start = time.time()
|
||||
# yolo process
|
||||
det0 = yolo_process(img0, model, device, args, half)
|
||||
det0 = det0.cpu().detach().numpy()
|
||||
# t_yolo = time.time()
|
||||
# print(f't_yolo. ({t_yolo - t_start:.3f}s)')
|
||||
t_yolo = time.time()
|
||||
print(f't_yolo. ({t_yolo - t_start:.3f}s)')
|
||||
|
||||
# plot所有box
|
||||
# for *xyxy, conf, cls in reversed(det0):
|
||||
|
|
@ -71,8 +75,8 @@ def main():
|
|||
det1 = DMPR_process(img0, DMPRmodel, device, args)
|
||||
det1 = det1.cpu().detach().numpy()
|
||||
|
||||
# t_dmpr = time.time()
|
||||
# print(f't_dmpr. ({t_dmpr - t_yolo:.3f}s)')
|
||||
t_dmpr = time.time()
|
||||
print(f't_dmpr. ({t_dmpr - t_yolo:.3f}s)')
|
||||
|
||||
# 绘制角点
|
||||
plot_points(img0, det1)
|
||||
|
|
@ -82,10 +86,10 @@ def main():
|
|||
|
||||
# yolo joint DMPR
|
||||
cls = 0 #需要过滤的box类别
|
||||
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls)
|
||||
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio)
|
||||
|
||||
# t_joint = time.time()
|
||||
# print(f't_joint. ({t_joint - t_dmpr:.3f}s)')
|
||||
t_joint = time.time()
|
||||
print(f't_joint. ({t_joint - t_dmpr:.3f}s)')
|
||||
|
||||
# t_end = time.time()
|
||||
# print(f'Done. ({t_end - t_start:.3f}s)')
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import math
|
||||
import os
|
||||
import time
|
||||
from collections import namedtuple
|
||||
|
||||
import cv2
|
||||
|
|
@ -132,7 +133,12 @@ def get_predicted_points2(prediction, thresh):
|
|||
|
||||
def detect_marking_points(detector, image, thresh, device):
|
||||
"""Given image read from opencv, return detected marking points."""
|
||||
t1 = time.time()
|
||||
torch.cuda.synchronize(device)
|
||||
prediction = detector(preprocess_image(image).to(device))
|
||||
torch.cuda.synchronize(device)
|
||||
t2 = time.time()
|
||||
print(f'detector: {t2 - t1:.3f}s')
|
||||
return get_predicted_points2(prediction[0], thresh)
|
||||
|
||||
def scale_coords2(img1_shape, coords, img0_shape, ratio_pad=None):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
|
||||
def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int):
|
||||
def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int, scale_ratio):
|
||||
|
||||
# 创建yolo_det_clone内容为x1, y1, x2, y2, conf, cls, unlabel (unlabel代表该类是否需要忽略,0:不忽略 其他:忽略)
|
||||
yolo_det_clone = yolo_det.copy()
|
||||
|
|
@ -22,8 +22,8 @@ def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int):
|
|||
y_length = yolo_det[:, 3] - yolo_det[:, 1] #y2-y1
|
||||
|
||||
# x, y哪个方向差值大哪个方向膨胀的多
|
||||
x_dilate_coefficient = ((x_length > y_length) + 1)*0.2
|
||||
y_dilate_coefficient = ((~(x_length > y_length)) + 1)*0.2
|
||||
x_dilate_coefficient = ((x_length > y_length) + 1)*scale_ratio
|
||||
y_dilate_coefficient = ((~(x_length > y_length)) + 1)*scale_ratio
|
||||
|
||||
# 膨胀
|
||||
new_yolo_det[:, 0] = np.round(yolo_det[:, 0] - x_dilate_coefficient * x_length).clip(0, img_shape[1]) #x1 膨胀
|
||||
|
|
@ -64,12 +64,24 @@ def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int):
|
|||
|
||||
direction1 = np.arctan2(y_c - y_p, x_c - x_p) / math.pi * 180
|
||||
direction2 = yolo_dmpr[..., 8] / math.pi * 180
|
||||
# direction3 = (direction2 + 90) if (direction2 + 90) <= 180 else (direction2 - 270)
|
||||
direction3 = direction2 + 90 # L形角点另外一个方向
|
||||
direction3[direction3 > 180] -= 360
|
||||
ang_diff = direction1 - direction2
|
||||
ang_diff2 = direction1 - direction3
|
||||
|
||||
# 判断膨胀后yolo框包含角点关系 & & 包含角点的时候计算水平框中心点与角点的角度关系
|
||||
# direction ∈ (-180, 180) 若角差大于180,需算补角
|
||||
# T形角点比较一个方向,L形角点比较两个方向
|
||||
mask = (x_p >= x1) & (x_p <= x2) & (y_p >= y1) & (y_p <= y2) & \
|
||||
(((ang_diff >= -90) & (ang_diff <= 90)) | ((ang_diff > 180) & ((360 - ang_diff) <= 90)) | (((ang_diff) < -180) & ((360 + ang_diff) <= 90)))
|
||||
(((yolo_dmpr[..., 9] <= 0.5) & # T形角点情况
|
||||
(((ang_diff >= -90) & (ang_diff <= 90)) | ((ang_diff > 180) & ((360 - ang_diff) <= 90)) |
|
||||
(((ang_diff) < -180) & ((360 + ang_diff) <= 90)))) |
|
||||
((yolo_dmpr[..., 9] > 0.5) & # L形角点情况
|
||||
(((ang_diff >= -90) & (ang_diff <= 90)) | ((ang_diff > 180) & ((360 - ang_diff) <= 90)) |
|
||||
(((ang_diff) < -180) & ((360 + ang_diff) <= 90))) &
|
||||
(((ang_diff2 >= -90) & (ang_diff2 <= 90)) | ((ang_diff2 > 180) & ((360 - ang_diff2) <= 90)) |
|
||||
(((ang_diff2) < -180) & ((360 + ang_diff2) <= 90)))))
|
||||
|
||||
res = np.sum(mask, axis=1)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,11 +5,14 @@ import argparse
|
|||
# 5: sin(direction)
|
||||
NUM_FEATURE_MAP_CHANNEL = 6
|
||||
|
||||
|
||||
def add_common_arguments(parser):
|
||||
"""Add common arguments for training and inference."""
|
||||
parser.add_argument('--detector_weights', default=r'E:\pycharmProject\DMPR-PS\weights\dp_detector_499.pth',
|
||||
help="The weights of pretrained detector.")
|
||||
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml',
|
||||
help='model.yaml path')
|
||||
parser.add_argument('--hyp', type=str, default='conf/hyp.scratch.yaml',
|
||||
help='hyperparameters path')
|
||||
parser.add_argument('--depth_factor', type=int, default=32,
|
||||
help="Depth factor.")
|
||||
parser.add_argument('--disable_cuda', action='store_true',
|
||||
|
|
@ -41,6 +44,8 @@ def get_parser_for_inference():
|
|||
help='object confidence threshold')
|
||||
parser.add_argument('--iou-thres', type=float, default=0.45,
|
||||
help='IOU threshold for NMS')
|
||||
parser.add_argument('--scale-ratio', type=float, default=0.5,
|
||||
help='detected box scale ratio')
|
||||
parser.add_argument('--ovlap-thres', type=float, default=0.6, help='overlap threshold for OBS')
|
||||
parser.add_argument('--agnostic-nms', action='store_true',
|
||||
help='class-agnostic NMS')
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue