Selaa lähdekoodia

1

add_stdc_seg
nyh 10 kuukautta sitten
vanhempi
commit
5dba294409
3 muutettua tiedostoa jossa 23 lisäystä ja 18 poistoa
  1. +19
    -14
      AI_example.py
  2. +1
    -1
      DMPRUtils/yolo_net.py
  3. +3
    -3
      conf/config.py

+ 19
- 14
AI_example.py Näytä tiedosto

@@ -2,6 +2,7 @@ import os
import time

import cv2
import numpy as np
import torch

from DMPRUtils.DMPR_process import DMPR_process, plot_points
@@ -23,9 +24,9 @@ def main():
device_ = '0' ##选定模型,可选 cpu,'0','1'

##以下参数目前不可改
Detweights = 'weights/urbanManagement/yolo/best1023.pt'
Detweights = 'weights/urbanManagement/yolo/best1201.pt'
seg_nclass = 2
DMPRweights = "weights/urbanManagement/DMPR/dp_detector_299_1023.pth"
DMPRweights = "weights/urbanManagement/DMPR/dp_detector_372_1204.pth"
conf_thres, iou_thres, classes = 0.25, 0.45, 3
labelnames = "weights/yolov5/class5/labelnames.json"
rainbows = [[0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 129, 0], [255, 0, 127],
@@ -62,7 +63,9 @@ def main():

# 图像测试
# impth = 'images/input'
impth = 'images/debug'
# impth = 'images/debug'
# impth = '/home/thsw/ssd/zjc/cityManagement_test'
impth = '/home/thsw/WJ/zjc/AI/images/pic2'
# impth = '/home/thsw/WJ/zjc/AI_old/images/input_0'
# outpth = 'images/output'
outpth = 'images/debug_out'
@@ -82,7 +85,7 @@ def main():
t_stdc = time.time()
# STDC process
det2 = STDC_process(img0, STDC_model, device, args.stdc_new_hw)
# det2[det2 == 1] = 255
det2[det2 == 1] = 255
t_stdc_inf = time.time()
print(f't_stdc_inf. ({t_stdc_inf - t_stdc:.3f}s)')
# STDC joint yolo
@@ -102,7 +105,7 @@ def main():
print(f't_dmpr. ({t_dmpr - t_yolo:.3f}s)')

# 绘制角点
# plot_points(img0, det1)
plot_points(img0, det1)

# yolo joint DMPR
cls = 0 #需要过滤的box类别
@@ -114,19 +117,21 @@ def main():
t_end = time.time()
print(f'Done. ({t_end - t_start:.3f}s)')
# 绘制膨胀box
# for *xyxy, flag in dilate_box:
# plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2)
for *xyxy, flag in dilate_box:
plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2)
# #
# # 绘制删除满足 在膨胀框内 && 角度差小于90度 的box
# for *xyxy, conf, cls, flag in reversed(joint_det):
# if flag == 0:
# # label = f'{int(cls)} {conf:.2f}'
# label = None
# plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2)
for *xyxy, conf, cls, flag in reversed(joint_det):
if flag == 0:
# label = f'{int(cls)} {conf:.2f}'
label = None
plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2)

# save
# save_path = os.path.join(outpth, file)
# cv2.imwrite(save_path, file)
mask = det2[..., np.newaxis].repeat(3, 2)
img_seg = 0.3*mask + img0
save_path = os.path.join(outpth, file)
cv2.imwrite(save_path, img_seg)




+ 1
- 1
DMPRUtils/yolo_net.py Näytä tiedosto

@@ -56,7 +56,7 @@ class Detect(nn.Module):
# y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
# z.append(y.view(bs, -1, self.no))

prediction = self.m[1](x[1])
prediction = self.m[0](x[0])
point_pred, angle_pred = torch.split(prediction, 4, dim=1)
point_pred = torch.sigmoid(point_pred)
angle_pred = torch.tanh(angle_pred)

+ 3
- 3
conf/config.py Näytä tiedosto

@@ -7,7 +7,7 @@ NUM_FEATURE_MAP_CHANNEL = 6

def add_common_arguments(parser):
"""Add common arguments for training and inference."""
parser.add_argument('--detector_weights', default=r'E:\pycharmProject\DMPR-PS\weights\dp_detector_299.pth',
parser.add_argument('--detector_weights', default=r'E:\pycharmProject\DMPR-PS\weights\dp_detector_372_1204.pth',
help="The weights of pretrained detector.")
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml',
help='model.yaml path')
@@ -53,12 +53,12 @@ def get_parser_for_inference():
help='class-agnostic NMS')
parser.add_argument('--classes', nargs='+', type=int,
help='filter by class: --class 0, or --class 0 2 3')
parser.add_argument('--dmpr-thresh', type=float, default=0.3,
parser.add_argument('--dmpr-thresh', type=float, default=0.1,
help="Detection threshold.")
# STDC
parser.add_argument('--n-classes', type=int, default=2, help='number of classes for segment')
parser.add_argument('--backbone', type=str, default='STDCNet813', help='STDC backbone')
parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final_1023.pth',
parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final_1123.pth',
help='The weights of STDC')
parser.add_argument('--stdc-new-hw', nargs='+', type=int, default=[360, 640], help='The new hw of STDC')
parser.add_argument('--use-boundary-2', type=bool, default=False, help='')

Loading…
Peruuta
Tallenna