This commit is contained in:
parent
4c861391aa
commit
0240c28698
|
|
@ -4,7 +4,7 @@
|
|||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.8 (yolov5)" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.8 (yolov5_3090)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (yolov5)" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (yolov5_3090)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
|
|
@ -20,7 +20,7 @@ from utils.torch_utils import select_device
|
|||
|
||||
def main():
|
||||
##预先设置的参数
|
||||
device_ = '1' ##选定模型,可选 cpu,'0','1'
|
||||
device_ = '0' ##选定模型,可选 cpu,'0','1'
|
||||
|
||||
##以下参数目前不可改
|
||||
Detweights = 'weights/urbanManagement/yolo/best1023.pt'
|
||||
|
|
@ -79,12 +79,16 @@ def main():
|
|||
t_yolo = time.time()
|
||||
print(f't_yolo. ({t_yolo - t_start:.3f}s)')
|
||||
|
||||
t_stdc = time.time()
|
||||
# STDC process
|
||||
det2 = STDC_process(img0, STDC_model, device, args.n_classes, args.stdc_scale)
|
||||
|
||||
det2 = STDC_process(img0, STDC_model, device, args.stdc_new_hw)
|
||||
# det2[det2 == 1] = 255
|
||||
t_stdc_inf = time.time()
|
||||
print(f't_stdc_inf. ({t_stdc_inf - t_stdc:.3f}s)')
|
||||
# STDC joint yolo
|
||||
det0 = stdc_yolo(det2, det0)
|
||||
|
||||
t_stdc_yolo = time.time()
|
||||
print(f't_stdc_joint. ({t_stdc_yolo - t_stdc_inf:.3f}s)')
|
||||
# plot所有box
|
||||
# for *xyxy, conf, cls in reversed(det0):
|
||||
# label = f'{int(cls)} {conf:.2f}'
|
||||
|
|
@ -93,36 +97,36 @@ def main():
|
|||
# DMPR process
|
||||
det1 = DMPR_process(img0, DMPRmodel, device, args)
|
||||
det1 = det1.cpu().detach().numpy()
|
||||
|
||||
#
|
||||
t_dmpr = time.time()
|
||||
print(f't_dmpr. ({t_dmpr - t_yolo:.3f}s)')
|
||||
|
||||
# 绘制角点
|
||||
plot_points(img0, det1)
|
||||
# plot_points(img0, det1)
|
||||
|
||||
# yolo joint DMPR
|
||||
cls = 0 #需要过滤的box类别
|
||||
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio, args.border)
|
||||
|
||||
#
|
||||
t_joint = time.time()
|
||||
print(f't_joint. ({t_joint - t_dmpr:.3f}s)')
|
||||
|
||||
# t_end = time.time()
|
||||
# print(f'Done. ({t_end - t_start:.3f}s)')
|
||||
t_end = time.time()
|
||||
print(f'Done. ({t_end - t_start:.3f}s)')
|
||||
# 绘制膨胀box
|
||||
for *xyxy, flag in dilate_box:
|
||||
plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2)
|
||||
#
|
||||
# for *xyxy, flag in dilate_box:
|
||||
# plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2)
|
||||
# #
|
||||
# # 绘制删除满足 在膨胀框内 && 角度差小于90度 的box
|
||||
for *xyxy, conf, cls, flag in reversed(joint_det):
|
||||
if flag == 0:
|
||||
# label = f'{int(cls)} {conf:.2f}'
|
||||
label = None
|
||||
plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2)
|
||||
# for *xyxy, conf, cls, flag in reversed(joint_det):
|
||||
# if flag == 0:
|
||||
# # label = f'{int(cls)} {conf:.2f}'
|
||||
# label = None
|
||||
# plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2)
|
||||
|
||||
# save
|
||||
save_path = os.path.join(outpth, file)
|
||||
cv2.imwrite(save_path, img0)
|
||||
# save_path = os.path.join(outpth, file)
|
||||
# cv2.imwrite(save_path, file)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -133,12 +133,12 @@ def get_predicted_points2(prediction, thresh):
|
|||
|
||||
def detect_marking_points(detector, image, thresh, device):
|
||||
"""Given image read from opencv, return detected marking points."""
|
||||
t1 = time.time()
|
||||
torch.cuda.synchronize(device)
|
||||
# t1 = time.time()
|
||||
# torch.cuda.synchronize(device)
|
||||
prediction = detector(preprocess_image(image).to(device))
|
||||
torch.cuda.synchronize(device)
|
||||
t2 = time.time()
|
||||
print(f'detector: {t2 - t1:.3f}s')
|
||||
# torch.cuda.synchronize(device)
|
||||
# t2 = time.time()
|
||||
# print(f'detector: {t2 - t1:.3f}s')
|
||||
return get_predicted_points2(prediction[0], thresh)
|
||||
|
||||
def scale_coords2(img1_shape, coords, img0_shape, ratio_pad=None):
|
||||
|
|
|
|||
|
|
@ -58,9 +58,9 @@ def get_parser_for_inference():
|
|||
# STDC
|
||||
parser.add_argument('--n-classes', type=int, default=2, help='number of classes for segment')
|
||||
parser.add_argument('--backbone', type=str, default='STDCNet813', help='STDC backbone')
|
||||
parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final.pth',
|
||||
parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final_1023.pth',
|
||||
help='The weights of STDC')
|
||||
parser.add_argument('--stdc-scale', type=float, default=0.75, help='The scale of STDC')
|
||||
parser.add_argument('--stdc-new-hw', nargs='+', type=int, default=[360, 640], help='The new hw of STDC')
|
||||
parser.add_argument('--use-boundary-2', type=bool, default=False, help='')
|
||||
parser.add_argument('--use-boundary-4', type=bool, default=False, help='')
|
||||
parser.add_argument('--use-boundary-8', type=bool, default=False, help='')
|
||||
|
|
|
|||
Loading…
Reference in New Issue