|
|
@@ -2,6 +2,7 @@ import os |
|
|
|
import time |
|
|
|
|
|
|
|
import cv2 |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
|
|
|
|
from DMPRUtils.DMPR_process import DMPR_process, plot_points |
|
|
@@ -23,9 +24,9 @@ def main(): |
|
|
|
device_ = '0' ##选定模型,可选 cpu,'0','1' |
|
|
|
|
|
|
|
##以下参数目前不可改 |
|
|
|
Detweights = 'weights/urbanManagement/yolo/best1023.pt' |
|
|
|
Detweights = 'weights/urbanManagement/yolo/best1201.pt' |
|
|
|
seg_nclass = 2 |
|
|
|
DMPRweights = "weights/urbanManagement/DMPR/dp_detector_299_1023.pth" |
|
|
|
DMPRweights = "weights/urbanManagement/DMPR/dp_detector_372_1204.pth" |
|
|
|
conf_thres, iou_thres, classes = 0.25, 0.45, 3 |
|
|
|
labelnames = "weights/yolov5/class5/labelnames.json" |
|
|
|
rainbows = [[0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 129, 0], [255, 0, 127], |
|
|
@@ -62,7 +63,9 @@ def main(): |
|
|
|
|
|
|
|
# 图像测试 |
|
|
|
# impth = 'images/input' |
|
|
|
impth = 'images/debug' |
|
|
|
# impth = 'images/debug' |
|
|
|
# impth = '/home/thsw/ssd/zjc/cityManagement_test' |
|
|
|
impth = '/home/thsw/WJ/zjc/AI/images/pic2' |
|
|
|
# impth = '/home/thsw/WJ/zjc/AI_old/images/input_0' |
|
|
|
# outpth = 'images/output' |
|
|
|
outpth = 'images/debug_out' |
|
|
@@ -82,7 +85,7 @@ def main(): |
|
|
|
t_stdc = time.time() |
|
|
|
# STDC process |
|
|
|
det2 = STDC_process(img0, STDC_model, device, args.stdc_new_hw) |
|
|
|
# det2[det2 == 1] = 255 |
|
|
|
det2[det2 == 1] = 255 |
|
|
|
t_stdc_inf = time.time() |
|
|
|
print(f't_stdc_inf. ({t_stdc_inf - t_stdc:.3f}s)') |
|
|
|
# STDC joint yolo |
|
|
@@ -102,7 +105,7 @@ def main(): |
|
|
|
print(f't_dmpr. ({t_dmpr - t_yolo:.3f}s)') |
|
|
|
|
|
|
|
# 绘制角点 |
|
|
|
# plot_points(img0, det1) |
|
|
|
plot_points(img0, det1) |
|
|
|
|
|
|
|
# yolo joint DMPR |
|
|
|
cls = 0 #需要过滤的box类别 |
|
|
@@ -114,19 +117,21 @@ def main(): |
|
|
|
t_end = time.time() |
|
|
|
print(f'Done. ({t_end - t_start:.3f}s)') |
|
|
|
# 绘制膨胀box |
|
|
|
# for *xyxy, flag in dilate_box: |
|
|
|
# plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2) |
|
|
|
for *xyxy, flag in dilate_box: |
|
|
|
plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2) |
|
|
|
# # |
|
|
|
# # 绘制删除满足 在膨胀框内 && 角度差小于90度 的box |
|
|
|
# for *xyxy, conf, cls, flag in reversed(joint_det): |
|
|
|
# if flag == 0: |
|
|
|
# # label = f'{int(cls)} {conf:.2f}' |
|
|
|
# label = None |
|
|
|
# plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2) |
|
|
|
for *xyxy, conf, cls, flag in reversed(joint_det): |
|
|
|
if flag == 0: |
|
|
|
# label = f'{int(cls)} {conf:.2f}' |
|
|
|
label = None |
|
|
|
plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2) |
|
|
|
|
|
|
|
# save |
|
|
|
# save_path = os.path.join(outpth, file) |
|
|
|
# cv2.imwrite(save_path, file) |
|
|
|
mask = det2[..., np.newaxis].repeat(3, 2) |
|
|
|
img_seg = 0.3*mask + img0 |
|
|
|
save_path = os.path.join(outpth, file) |
|
|
|
cv2.imwrite(save_path, img_seg) |
|
|
|
|
|
|
|
|
|
|
|
|