nyh vor 1 Jahr
Ursprung
Commit
0240c28698
5 geänderte Dateien mit 32 neuen und 28 gelöschten Zeilen
  1. +1
    -1
      .idea/AI.iml
  2. +1
    -1
      .idea/misc.xml
  3. +23
    -19
      AI_example.py
  4. +5
    -5
      DMPRUtils/DMPR_process.py
  5. +2
    -2
      conf/config.py

+ 1
- 1
.idea/AI.iml Datei anzeigen

@@ -4,7 +4,7 @@
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/venv" />
</content>
<orderEntry type="jdk" jdkName="Python 3.8 (yolov5)" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Python 3.8 (yolov5_3090)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

+ 1
- 1
.idea/misc.xml Datei anzeigen

@@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (yolov5)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (yolov5_3090)" project-jdk-type="Python SDK" />
</project>

+ 23
- 19
AI_example.py Datei anzeigen

@@ -20,7 +20,7 @@ from utils.torch_utils import select_device

def main():
##预先设置的参数
device_ = '1' ##选定模型,可选 cpu,'0','1'
device_ = '0' ##选定模型,可选 cpu,'0','1'

##以下参数目前不可改
Detweights = 'weights/urbanManagement/yolo/best1023.pt'
@@ -79,12 +79,16 @@ def main():
t_yolo = time.time()
print(f't_yolo. ({t_yolo - t_start:.3f}s)')

t_stdc = time.time()
# STDC process
det2 = STDC_process(img0, STDC_model, device, args.n_classes, args.stdc_scale)

det2 = STDC_process(img0, STDC_model, device, args.stdc_new_hw)
# det2[det2 == 1] = 255
t_stdc_inf = time.time()
print(f't_stdc_inf. ({t_stdc_inf - t_stdc:.3f}s)')
# STDC joint yolo
det0 = stdc_yolo(det2, det0)

t_stdc_yolo = time.time()
print(f't_stdc_joint. ({t_stdc_yolo - t_stdc_inf:.3f}s)')
# plot所有box
# for *xyxy, conf, cls in reversed(det0):
# label = f'{int(cls)} {conf:.2f}'
@@ -93,36 +97,36 @@ def main():
# DMPR process
det1 = DMPR_process(img0, DMPRmodel, device, args)
det1 = det1.cpu().detach().numpy()
#
t_dmpr = time.time()
print(f't_dmpr. ({t_dmpr - t_yolo:.3f}s)')

# 绘制角点
plot_points(img0, det1)
# plot_points(img0, det1)

# yolo joint DMPR
cls = 0 #需要过滤的box类别
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio, args.border)
#
t_joint = time.time()
print(f't_joint. ({t_joint - t_dmpr:.3f}s)')

# t_end = time.time()
# print(f'Done. ({t_end - t_start:.3f}s)')
t_end = time.time()
print(f'Done. ({t_end - t_start:.3f}s)')
# 绘制膨胀box
for *xyxy, flag in dilate_box:
plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2)
#
# for *xyxy, flag in dilate_box:
# plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2)
# #
# # 绘制删除满足 在膨胀框内 && 角度差小于90度 的box
for *xyxy, conf, cls, flag in reversed(joint_det):
if flag == 0:
# label = f'{int(cls)} {conf:.2f}'
label = None
plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2)
# for *xyxy, conf, cls, flag in reversed(joint_det):
# if flag == 0:
# # label = f'{int(cls)} {conf:.2f}'
# label = None
# plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2)

# save
save_path = os.path.join(outpth, file)
cv2.imwrite(save_path, img0)
# save_path = os.path.join(outpth, file)
# cv2.imwrite(save_path, file)




+ 5
- 5
DMPRUtils/DMPR_process.py Datei anzeigen

@@ -133,12 +133,12 @@ def get_predicted_points2(prediction, thresh):

def detect_marking_points(detector, image, thresh, device):
"""Given image read from opencv, return detected marking points."""
t1 = time.time()
torch.cuda.synchronize(device)
# t1 = time.time()
# torch.cuda.synchronize(device)
prediction = detector(preprocess_image(image).to(device))
torch.cuda.synchronize(device)
t2 = time.time()
print(f'detector: {t2 - t1:.3f}s')
# torch.cuda.synchronize(device)
# t2 = time.time()
# print(f'detector: {t2 - t1:.3f}s')
return get_predicted_points2(prediction[0], thresh)

def scale_coords2(img1_shape, coords, img0_shape, ratio_pad=None):

+ 2
- 2
conf/config.py Datei anzeigen

@@ -58,9 +58,9 @@ def get_parser_for_inference():
# STDC
parser.add_argument('--n-classes', type=int, default=2, help='number of classes for segment')
parser.add_argument('--backbone', type=str, default='STDCNet813', help='STDC backbone')
parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final.pth',
parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final_1023.pth',
help='The weights of STDC')
parser.add_argument('--stdc-scale', type=float, default=0.75, help='The scale of STDC')
parser.add_argument('--stdc-new-hw', nargs='+', type=int, default=[360, 640], help='The new hw of STDC')
parser.add_argument('--use-boundary-2', type=bool, default=False, help='')
parser.add_argument('--use-boundary-4', type=bool, default=False, help='')
parser.add_argument('--use-boundary-8', type=bool, default=False, help='')

Laden…
Abbrechen
Speichern