diff --git a/.idea/AI.iml b/.idea/AI.iml
index 513351e..fb9d9f6 100644
--- a/.idea/AI.iml
+++ b/.idea/AI.iml
@@ -4,7 +4,7 @@
-
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
index da088d4..f148f03 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,4 +1,4 @@
-
+
\ No newline at end of file
diff --git a/AI_example.py b/AI_example.py
index 307df9e..1661cb1 100644
--- a/AI_example.py
+++ b/AI_example.py
@@ -20,7 +20,7 @@ from utils.torch_utils import select_device
def main():
##预先设置的参数
- device_ = '1' ##选定模型,可选 cpu,'0','1'
+ device_ = '0' ##选定模型,可选 cpu,'0','1'
##以下参数目前不可改
Detweights = 'weights/urbanManagement/yolo/best1023.pt'
@@ -79,12 +79,16 @@ def main():
t_yolo = time.time()
print(f't_yolo. ({t_yolo - t_start:.3f}s)')
+ t_stdc = time.time()
# STDC process
- det2 = STDC_process(img0, STDC_model, device, args.n_classes, args.stdc_scale)
-
+ det2 = STDC_process(img0, STDC_model, device, args.stdc_new_hw)
+ # det2[det2 == 1] = 255
+ t_stdc_inf = time.time()
+ print(f't_stdc_inf. ({t_stdc_inf - t_stdc:.3f}s)')
# STDC joint yolo
det0 = stdc_yolo(det2, det0)
-
+ t_stdc_yolo = time.time()
+ print(f't_stdc_joint. ({t_stdc_yolo - t_stdc_inf:.3f}s)')
# plot所有box
# for *xyxy, conf, cls in reversed(det0):
# label = f'{int(cls)} {conf:.2f}'
@@ -93,36 +97,36 @@ def main():
# DMPR process
det1 = DMPR_process(img0, DMPRmodel, device, args)
det1 = det1.cpu().detach().numpy()
-
+ #
t_dmpr = time.time()
print(f't_dmpr. ({t_dmpr - t_yolo:.3f}s)')
# 绘制角点
- plot_points(img0, det1)
+ # plot_points(img0, det1)
# yolo joint DMPR
cls = 0 #需要过滤的box类别
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio, args.border)
-
+ #
t_joint = time.time()
print(f't_joint. ({t_joint - t_dmpr:.3f}s)')
- # t_end = time.time()
- # print(f'Done. ({t_end - t_start:.3f}s)')
+ t_end = time.time()
+ print(f'Done. ({t_end - t_start:.3f}s)')
# 绘制膨胀box
- for *xyxy, flag in dilate_box:
- plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2)
- #
+ # for *xyxy, flag in dilate_box:
+ # plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2)
+ # #
# # 绘制删除满足 在膨胀框内 && 角度差小于90度 的box
- for *xyxy, conf, cls, flag in reversed(joint_det):
- if flag == 0:
- # label = f'{int(cls)} {conf:.2f}'
- label = None
- plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2)
+ # for *xyxy, conf, cls, flag in reversed(joint_det):
+ # if flag == 0:
+ # # label = f'{int(cls)} {conf:.2f}'
+ # label = None
+ # plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2)
# save
- save_path = os.path.join(outpth, file)
- cv2.imwrite(save_path, img0)
+ # save_path = os.path.join(outpth, file)
+ # cv2.imwrite(save_path, file)
diff --git a/DMPRUtils/DMPR_process.py b/DMPRUtils/DMPR_process.py
index 0409deb..a3cea0d 100644
--- a/DMPRUtils/DMPR_process.py
+++ b/DMPRUtils/DMPR_process.py
@@ -133,12 +133,12 @@ def get_predicted_points2(prediction, thresh):
def detect_marking_points(detector, image, thresh, device):
"""Given image read from opencv, return detected marking points."""
- t1 = time.time()
- torch.cuda.synchronize(device)
+ # t1 = time.time()
+ # torch.cuda.synchronize(device)
prediction = detector(preprocess_image(image).to(device))
- torch.cuda.synchronize(device)
- t2 = time.time()
- print(f'detector: {t2 - t1:.3f}s')
+ # torch.cuda.synchronize(device)
+ # t2 = time.time()
+ # print(f'detector: {t2 - t1:.3f}s')
return get_predicted_points2(prediction[0], thresh)
def scale_coords2(img1_shape, coords, img0_shape, ratio_pad=None):
diff --git a/conf/config.py b/conf/config.py
index e81a741..924485a 100644
--- a/conf/config.py
+++ b/conf/config.py
@@ -58,9 +58,9 @@ def get_parser_for_inference():
# STDC
parser.add_argument('--n-classes', type=int, default=2, help='number of classes for segment')
parser.add_argument('--backbone', type=str, default='STDCNet813', help='STDC backbone')
- parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final.pth',
+ parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final_1023.pth',
help='The weights of STDC')
- parser.add_argument('--stdc-scale', type=float, default=0.75, help='The scale of STDC')
+ parser.add_argument('--stdc-new-hw', nargs='+', type=int, default=[360, 640], help='The new hw of STDC')
parser.add_argument('--use-boundary-2', type=bool, default=False, help='')
parser.add_argument('--use-boundary-4', type=bool, default=False, help='')
parser.add_argument('--use-boundary-8', type=bool, default=False, help='')