2023-08-17 11:59:31 +08:00
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from DMPRUtils.DMPR_process import DMPR_process, plot_points
|
|
|
|
|
from DMPRUtils.model.detector import DirectionalPointDetector
|
2023-09-19 16:02:42 +08:00
|
|
|
from DMPRUtils.yolo_net import Model
|
2023-08-17 11:59:31 +08:00
|
|
|
from DMPR_YOLO.jointUtil import dmpr_yolo
|
2023-11-01 16:53:11 +08:00
|
|
|
from STDCUtils.STDC_process import STDC_process
|
|
|
|
|
from STDCUtils.models.model_stages import BiSeNet
|
|
|
|
|
from STDC_YOLO.yolo_stdc_joint import stdc_yolo
|
2023-08-17 11:59:31 +08:00
|
|
|
from conf import config
|
|
|
|
|
from models.experimental import attempt_load
|
|
|
|
|
from models.yolo_process import yolo_process
|
|
|
|
|
from utils.plots import plot_one_box
|
|
|
|
|
from utils.torch_utils import select_device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Colour palette indexed by class id (BGR triples as passed to plot_one_box).
_RAINBOWS = [[0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 129, 0], [255, 0, 127],
             [127, 255, 0], [0, 255, 127], [0, 127, 255], [127, 0, 255], [255, 127, 255], [255, 255, 127],
             [127, 255, 255], [0, 255, 255], [255, 127, 255], [127, 255, 255], [0, 127, 0], [0, 0, 127],
             [0, 255, 255]]


def _list_images(folder):
    """Return the names of the regular files directly inside *folder*, sorted.

    Sub-directories are skipped so cv2.imread is never handed a directory.
    Sorting makes the processing order (and the timing printout) reproducible,
    because os.listdir returns entries in arbitrary order.
    """
    return sorted(name for name in os.listdir(folder)
                  if os.path.isfile(os.path.join(folder, name)))


def _load_models(device, half, args, det_weights, dmpr_weights):
    """Load the YOLO detector, the STDC segmenter and the DMPR model onto *device*.

    Args:
        device: torch device returned by select_device.
        half: convert the YOLO model to fp16 (only meaningful on CUDA).
        args: namespace from config.get_parser_for_inference(); supplies the
            STDC architecture options, the STDC checkpoint path (args.respth)
            and the DMPR model config (args.cfg).
        det_weights: path of the YOLO checkpoint.
        dmpr_weights: path of the DMPR state_dict checkpoint.

    Returns:
        (yolo_model, stdc_model, dmpr_model)
    """
    # yolov5 model
    yolo_model = attempt_load(det_weights, map_location=device)
    if half:
        yolo_model.half()

    # STDC model
    stdc_model = BiSeNet(backbone=args.backbone, n_classes=args.n_classes,
                         use_boundary_2=args.use_boundary_2, use_boundary_4=args.use_boundary_4,
                         use_boundary_8=args.use_boundary_8, use_boundary_16=args.use_boundary_16,
                         use_conv_last=args.use_conv_last).to(device)
    # map_location makes checkpoints saved on another device (e.g. a GPU) load
    # directly onto *device*, which also makes CPU-only runs possible.
    stdc_model.load_state_dict(torch.load(args.respth, map_location=device))
    stdc_model.eval()

    # DMPR model, built from the yolo_net config with 3 input channels
    dmpr_model = Model(args.cfg, ch=3).to(device)
    dmpr_model.load_state_dict(torch.load(dmpr_weights, map_location=device))
    # NOTE(review): unlike stdc_model, dmpr_model is never switched to eval()
    # here. Confirm whether DMPR_process does it; otherwise any BatchNorm/Dropout
    # layers run in training mode during inference. Not changed here because
    # the head's output format may depend on the mode.

    return yolo_model, stdc_model, dmpr_model


def _annotate_image(img0, models, device, args, half, filter_cls=0):
    """Run YOLO, STDC and DMPR on *img0* and draw the results onto it in place.

    Per-stage wall-clock timings are printed.

    Args:
        img0: BGR image as returned by cv2.imread; modified in place.
        models: (yolo_model, stdc_model, dmpr_model) as returned by _load_models.
        device, args, half: forwarded to the *_process helpers.
        filter_cls: box class id that dmpr_yolo filters against the DMPR points.
    """
    yolo_model, stdc_model, dmpr_model = models
    t_start = time.time()

    # yolo process
    det0 = yolo_process(img0, yolo_model, device, args, half)
    det0 = det0.cpu().detach().numpy()
    t_yolo = time.time()
    print(f't_yolo. ({t_yolo - t_start:.3f}s)')

    # STDC process, then combine the segmentation output with the yolo boxes
    det2 = STDC_process(img0, stdc_model, device, args.n_classes, args.stdc_scale)
    det0 = stdc_yolo(det2, det0)
    # Timed separately so that t_dmpr no longer silently includes the STDC stage.
    t_stdc = time.time()
    print(f't_stdc. ({t_stdc - t_yolo:.3f}s)')

    # DMPR process
    det1 = DMPR_process(img0, dmpr_model, device, args)
    det1 = det1.cpu().detach().numpy()
    t_dmpr = time.time()
    print(f't_dmpr. ({t_dmpr - t_stdc:.3f}s)')

    # draw corner points
    plot_points(img0, det1)

    # yolo joint DMPR
    joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, filter_cls, args.scale_ratio)
    t_joint = time.time()
    print(f't_joint. ({t_joint - t_dmpr:.3f}s)')

    # draw the dilated boxes
    for *xyxy, _flag in dilate_box:
        plot_one_box(xyxy, img0, color=_RAINBOWS[int(filter_cls)], line_thickness=2)

    # Draw the boxes that remain after removing those that lie inside a dilated
    # box with an angle difference under 90 degrees (per the original comment);
    # only entries with flag == 0 are drawn.
    for *xyxy, _conf, box_cls, flag in reversed(joint_det):
        if flag == 0:
            plot_one_box(xyxy, img0, label=None, color=_RAINBOWS[int(box_cls)], line_thickness=2)


def main():
    """Annotate every image in ``images/debug`` and save the results to ``images/debug_out``."""
    # Preset parameters
    device_ = '1'  # device to run on: 'cpu', '0', '1', ...

    # The following parameters cannot be changed for now
    Detweights = 'weights/urbanManagement/yolo/best1023.pt'
    DMPRweights = "weights/urbanManagement/DMPR/dp_detector_299_1023.pth"

    # Load the models
    device = select_device(device_)
    half = device.type != 'cpu'  # half precision only supported on CUDA
    # Parsed once; the same namespace configures model loading and per-image processing.
    args = config.get_parser_for_inference().parse_args()
    models = _load_models(device, half, args, Detweights, DMPRweights)

    # image test
    impth = 'images/debug'
    outpth = 'images/debug_out'
    # cv2.imwrite does not create missing directories; it only reports failure via its return value.
    os.makedirs(outpth, exist_ok=True)

    with torch.no_grad():  # inference only: no autograd graph is needed
        for file in _list_images(impth):
            imgpath = os.path.join(impth, file)
            img0 = cv2.imread(imgpath)
            if img0 is None:
                # Skip unreadable / non-image files instead of aborting the whole
                # batch (an assert would also be stripped under `python -O`).
                print('Image Not Found ' + imgpath)
                continue

            _annotate_image(img0, models, device, args, half)

            # save
            save_path = os.path.join(outpth, file)
            if not cv2.imwrite(save_path, img0):
                print('Failed to write ' + save_path)
if __name__ == "__main__":
    # Run the batch demo only when this file is executed as a script, not on import.
    main()
|