@@ -4,7 +4,7 @@ | |||
<content url="file://$MODULE_DIR$"> | |||
<excludeFolder url="file://$MODULE_DIR$/venv" /> | |||
</content> | |||
<orderEntry type="jdk" jdkName="Python 3.8 (yolov5)" jdkType="Python SDK" /> | |||
<orderEntry type="jdk" jdkName="Python 3.8 (yolov5_3090)" jdkType="Python SDK" /> | |||
<orderEntry type="sourceFolder" forTests="false" /> | |||
</component> | |||
</module> |
@@ -1,4 +1,4 @@ | |||
<?xml version="1.0" encoding="UTF-8"?> | |||
<project version="4"> | |||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (yolov5)" project-jdk-type="Python SDK" /> | |||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (yolov5_3090)" project-jdk-type="Python SDK" /> | |||
</project> |
@@ -20,7 +20,7 @@ from utils.torch_utils import select_device | |||
def main(): | |||
##预先设置的参数 | |||
device_ = '1' ##选定模型,可选 cpu,'0','1' | |||
device_ = '0' ##选定模型,可选 cpu,'0','1' | |||
##以下参数目前不可改 | |||
Detweights = 'weights/urbanManagement/yolo/best1023.pt' | |||
@@ -79,12 +79,16 @@ def main(): | |||
t_yolo = time.time() | |||
print(f't_yolo. ({t_yolo - t_start:.3f}s)') | |||
t_stdc = time.time() | |||
# STDC process | |||
det2 = STDC_process(img0, STDC_model, device, args.n_classes, args.stdc_scale) | |||
det2 = STDC_process(img0, STDC_model, device, args.stdc_new_hw) | |||
# det2[det2 == 1] = 255 | |||
t_stdc_inf = time.time() | |||
print(f't_stdc_inf. ({t_stdc_inf - t_stdc:.3f}s)') | |||
# STDC joint yolo | |||
det0 = stdc_yolo(det2, det0) | |||
t_stdc_yolo = time.time() | |||
print(f't_stdc_joint. ({t_stdc_yolo - t_stdc_inf:.3f}s)') | |||
# plot所有box | |||
# for *xyxy, conf, cls in reversed(det0): | |||
# label = f'{int(cls)} {conf:.2f}' | |||
@@ -93,36 +97,36 @@ def main(): | |||
# DMPR process | |||
det1 = DMPR_process(img0, DMPRmodel, device, args) | |||
det1 = det1.cpu().detach().numpy() | |||
# | |||
t_dmpr = time.time() | |||
print(f't_dmpr. ({t_dmpr - t_yolo:.3f}s)') | |||
# 绘制角点 | |||
plot_points(img0, det1) | |||
# plot_points(img0, det1) | |||
# yolo joint DMPR | |||
cls = 0 #需要过滤的box类别 | |||
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio, args.border) | |||
# | |||
t_joint = time.time() | |||
print(f't_joint. ({t_joint - t_dmpr:.3f}s)') | |||
# t_end = time.time() | |||
# print(f'Done. ({t_end - t_start:.3f}s)') | |||
t_end = time.time() | |||
print(f'Done. ({t_end - t_start:.3f}s)') | |||
# 绘制膨胀box | |||
for *xyxy, flag in dilate_box: | |||
plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2) | |||
# | |||
# for *xyxy, flag in dilate_box: | |||
# plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2) | |||
# # | |||
# # 绘制删除满足 在膨胀框内 && 角度差小于90度 的box | |||
for *xyxy, conf, cls, flag in reversed(joint_det): | |||
if flag == 0: | |||
# label = f'{int(cls)} {conf:.2f}' | |||
label = None | |||
plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2) | |||
# for *xyxy, conf, cls, flag in reversed(joint_det): | |||
# if flag == 0: | |||
# # label = f'{int(cls)} {conf:.2f}' | |||
# label = None | |||
# plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2) | |||
# save | |||
save_path = os.path.join(outpth, file) | |||
cv2.imwrite(save_path, img0) | |||
# save_path = os.path.join(outpth, file) | |||
# cv2.imwrite(save_path, file) | |||
@@ -133,12 +133,12 @@ def get_predicted_points2(prediction, thresh): | |||
def detect_marking_points(detector, image, thresh, device): | |||
"""Given image read from opencv, return detected marking points.""" | |||
t1 = time.time() | |||
torch.cuda.synchronize(device) | |||
# t1 = time.time() | |||
# torch.cuda.synchronize(device) | |||
prediction = detector(preprocess_image(image).to(device)) | |||
torch.cuda.synchronize(device) | |||
t2 = time.time() | |||
print(f'detector: {t2 - t1:.3f}s') | |||
# torch.cuda.synchronize(device) | |||
# t2 = time.time() | |||
# print(f'detector: {t2 - t1:.3f}s') | |||
return get_predicted_points2(prediction[0], thresh) | |||
def scale_coords2(img1_shape, coords, img0_shape, ratio_pad=None): |
@@ -58,9 +58,9 @@ def get_parser_for_inference(): | |||
# STDC | |||
parser.add_argument('--n-classes', type=int, default=2, help='number of classes for segment') | |||
parser.add_argument('--backbone', type=str, default='STDCNet813', help='STDC backbone') | |||
parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final.pth', | |||
parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final_1023.pth', | |||
help='The weights of STDC') | |||
parser.add_argument('--stdc-scale', type=float, default=0.75, help='The scale of STDC') | |||
parser.add_argument('--stdc-new-hw', nargs='+', type=int, default=[360, 640], help='The new hw of STDC') | |||
parser.add_argument('--use-boundary-2', type=bool, default=False, help='') | |||
parser.add_argument('--use-boundary-4', type=bool, default=False, help='') | |||
parser.add_argument('--use-boundary-8', type=bool, default=False, help='') |