From 0240c286982b48b58cc872688586c5dbbdd860d7 Mon Sep 17 00:00:00 2001 From: nyh <175484793@qq.com> Date: Thu, 2 Nov 2023 13:14:56 +0800 Subject: [PATCH] 3 --- .idea/AI.iml | 2 +- .idea/misc.xml | 2 +- AI_example.py | 42 +++++++++++++++++++++------------------ DMPRUtils/DMPR_process.py | 10 +++++----- conf/config.py | 4 ++-- 5 files changed, 32 insertions(+), 28 deletions(-) diff --git a/.idea/AI.iml b/.idea/AI.iml index 513351e..fb9d9f6 100644 --- a/.idea/AI.iml +++ b/.idea/AI.iml @@ -4,7 +4,7 @@ - + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index da088d4..f148f03 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ - + \ No newline at end of file diff --git a/AI_example.py b/AI_example.py index 307df9e..1661cb1 100644 --- a/AI_example.py +++ b/AI_example.py @@ -20,7 +20,7 @@ from utils.torch_utils import select_device def main(): ##预先设置的参数 - device_ = '1' ##选定模型,可选 cpu,'0','1' + device_ = '0' ##选定模型,可选 cpu,'0','1' ##以下参数目前不可改 Detweights = 'weights/urbanManagement/yolo/best1023.pt' @@ -79,12 +79,16 @@ def main(): t_yolo = time.time() print(f't_yolo. ({t_yolo - t_start:.3f}s)') + t_stdc = time.time() # STDC process - det2 = STDC_process(img0, STDC_model, device, args.n_classes, args.stdc_scale) - + det2 = STDC_process(img0, STDC_model, device, args.stdc_new_hw) + # det2[det2 == 1] = 255 + t_stdc_inf = time.time() + print(f't_stdc_inf. ({t_stdc_inf - t_stdc:.3f}s)') # STDC joint yolo det0 = stdc_yolo(det2, det0) - + t_stdc_yolo = time.time() + print(f't_stdc_joint. ({t_stdc_yolo - t_stdc_inf:.3f}s)') # plot所有box # for *xyxy, conf, cls in reversed(det0): # label = f'{int(cls)} {conf:.2f}' @@ -93,36 +97,36 @@ def main(): # DMPR process det1 = DMPR_process(img0, DMPRmodel, device, args) det1 = det1.cpu().detach().numpy() - + # t_dmpr = time.time() print(f't_dmpr. ({t_dmpr - t_yolo:.3f}s)') # 绘制角点 - plot_points(img0, det1) + # plot_points(img0, det1) # yolo joint DMPR cls = 0 #需要过滤的box类别 joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio, args.border) - + # t_joint = time.time() print(f't_joint. ({t_joint - t_dmpr:.3f}s)') - # t_end = time.time() - # print(f'Done. ({t_end - t_start:.3f}s)') + t_end = time.time() + print(f'Done. ({t_end - t_start:.3f}s)') # 绘制膨胀box - for *xyxy, flag in dilate_box: - plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2) - # + # for *xyxy, flag in dilate_box: + # plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2) + # # # # 绘制删除满足 在膨胀框内 && 角度差小于90度 的box - for *xyxy, conf, cls, flag in reversed(joint_det): - if flag == 0: - # label = f'{int(cls)} {conf:.2f}' - label = None - plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2) + # for *xyxy, conf, cls, flag in reversed(joint_det): + # if flag == 0: + # # label = f'{int(cls)} {conf:.2f}' + # label = None + # plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2) # save - save_path = os.path.join(outpth, file) - cv2.imwrite(save_path, img0) + # save_path = os.path.join(outpth, file) + # cv2.imwrite(save_path, file) diff --git a/DMPRUtils/DMPR_process.py b/DMPRUtils/DMPR_process.py index 0409deb..a3cea0d 100644 --- a/DMPRUtils/DMPR_process.py +++ b/DMPRUtils/DMPR_process.py @@ -133,12 +133,12 @@ def get_predicted_points2(prediction, thresh): def detect_marking_points(detector, image, thresh, device): """Given image read from opencv, return detected marking points.""" - t1 = time.time() - torch.cuda.synchronize(device) + # t1 = time.time() + # torch.cuda.synchronize(device) prediction = detector(preprocess_image(image).to(device)) - torch.cuda.synchronize(device) - t2 = time.time() - print(f'detector: {t2 - t1:.3f}s') + # torch.cuda.synchronize(device) + # t2 = time.time() + # print(f'detector: {t2 - t1:.3f}s') return get_predicted_points2(prediction[0], thresh) def scale_coords2(img1_shape, coords, img0_shape, ratio_pad=None): diff --git a/conf/config.py b/conf/config.py index e81a741..924485a 100644 --- a/conf/config.py +++ b/conf/config.py @@ -58,9 +58,9 @@ def get_parser_for_inference(): # STDC parser.add_argument('--n-classes', type=int, default=2, help='number of classes for segment') parser.add_argument('--backbone', type=str, default='STDCNet813', help='STDC backbone') - parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final.pth', + parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final_1023.pth', help='The weights of STDC') - parser.add_argument('--stdc-scale', type=float, default=0.75, help='The scale of STDC') + parser.add_argument('--stdc-new-hw', nargs='+', type=int, default=[360, 640], help='The new hw of STDC') parser.add_argument('--use-boundary-2', type=bool, default=False, help='') parser.add_argument('--use-boundary-4', type=bool, default=False, help='') parser.add_argument('--use-boundary-8', type=bool, default=False, help='')