This commit is contained in:
parent
807111ce13
commit
219a97f315
|
|
@ -8,6 +8,9 @@ from DMPRUtils.DMPR_process import DMPR_process, plot_points
|
||||||
from DMPRUtils.model.detector import DirectionalPointDetector
|
from DMPRUtils.model.detector import DirectionalPointDetector
|
||||||
from DMPRUtils.yolo_net import Model
|
from DMPRUtils.yolo_net import Model
|
||||||
from DMPR_YOLO.jointUtil import dmpr_yolo
|
from DMPR_YOLO.jointUtil import dmpr_yolo
|
||||||
|
from STDCUtils.STDC_process import STDC_process
|
||||||
|
from STDCUtils.models.model_stages import BiSeNet
|
||||||
|
from STDC_YOLO.yolo_stdc_joint import stdc_yolo
|
||||||
from conf import config
|
from conf import config
|
||||||
from models.experimental import attempt_load
|
from models.experimental import attempt_load
|
||||||
from models.yolo_process import yolo_process
|
from models.yolo_process import yolo_process
|
||||||
|
|
@ -20,9 +23,9 @@ def main():
|
||||||
device_ = '1' ##选定模型,可选 cpu,'0','1'
|
device_ = '1' ##选定模型,可选 cpu,'0','1'
|
||||||
|
|
||||||
##以下参数目前不可改
|
##以下参数目前不可改
|
||||||
Detweights = 'weights/urbanManagement/yolo/best.pt'
|
Detweights = 'weights/urbanManagement/yolo/best1023.pt'
|
||||||
seg_nclass = 2
|
seg_nclass = 2
|
||||||
DMPRweights = "weights/urbanManagement/DMPR/dp_detector_299.pth"
|
DMPRweights = "weights/urbanManagement/DMPR/dp_detector_299_1023.pth"
|
||||||
conf_thres, iou_thres, classes = 0.25, 0.45, 3
|
conf_thres, iou_thres, classes = 0.25, 0.45, 3
|
||||||
labelnames = "weights/yolov5/class5/labelnames.json"
|
labelnames = "weights/yolov5/class5/labelnames.json"
|
||||||
rainbows = [[0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 129, 0], [255, 0, 127],
|
rainbows = [[0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 129, 0], [255, 0, 127],
|
||||||
|
|
@ -40,6 +43,17 @@ def main():
|
||||||
if half:
|
if half:
|
||||||
model.half()
|
model.half()
|
||||||
|
|
||||||
|
# STDC model
|
||||||
|
args = config.get_parser_for_inference().parse_args()
|
||||||
|
|
||||||
|
# STDC model
|
||||||
|
STDC_model = BiSeNet(backbone=args.backbone, n_classes=args.n_classes,
|
||||||
|
use_boundary_2=args.use_boundary_2, use_boundary_4=args.use_boundary_4,
|
||||||
|
use_boundary_8=args.use_boundary_8, use_boundary_16=args.use_boundary_16,
|
||||||
|
use_conv_last=args.use_conv_last).to(device)
|
||||||
|
STDC_model.load_state_dict(torch.load(args.respth))
|
||||||
|
STDC_model.eval()
|
||||||
|
|
||||||
# DMPR model
|
# DMPR model
|
||||||
args = config.get_parser_for_inference().parse_args()
|
args = config.get_parser_for_inference().parse_args()
|
||||||
# DMPRmodel = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
|
# DMPRmodel = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
|
||||||
|
|
@ -66,6 +80,12 @@ def main():
|
||||||
t_yolo = time.time()
|
t_yolo = time.time()
|
||||||
print(f't_yolo. ({t_yolo - t_start:.3f}s)')
|
print(f't_yolo. ({t_yolo - t_start:.3f}s)')
|
||||||
|
|
||||||
|
# STDC process
|
||||||
|
det2 = STDC_process(img0, STDC_model, device, args.n_classes, args.stdc_scale)
|
||||||
|
|
||||||
|
# STDC joint yolo
|
||||||
|
det0 = stdc_yolo(det2, det0)
|
||||||
|
|
||||||
# plot所有box
|
# plot所有box
|
||||||
# for *xyxy, conf, cls in reversed(det0):
|
# for *xyxy, conf, cls in reversed(det0):
|
||||||
# label = f'{int(cls)} {conf:.2f}'
|
# label = f'{int(cls)} {conf:.2f}'
|
||||||
|
|
|
||||||
|
|
@ -101,6 +101,5 @@ def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int, scale_ratio):
|
||||||
|
|
||||||
yolo_det_clone[yolo_det_clone[:, -2] == cls, -1] = res
|
yolo_det_clone[yolo_det_clone[:, -2] == cls, -1] = res
|
||||||
|
|
||||||
|
|
||||||
return yolo_det_clone, new_yolo_det
|
return yolo_det_clone, new_yolo_det
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,5 +53,17 @@ def get_parser_for_inference():
|
||||||
help='filter by class: --class 0, or --class 0 2 3')
|
help='filter by class: --class 0, or --class 0 2 3')
|
||||||
parser.add_argument('--dmpr-thresh', type=float, default=0.3,
|
parser.add_argument('--dmpr-thresh', type=float, default=0.3,
|
||||||
help="Detection threshold.")
|
help="Detection threshold.")
|
||||||
|
# STDC
|
||||||
|
parser.add_argument('--n-classes', type=int, default=2, help='number of classes for segment')
|
||||||
|
parser.add_argument('--backbone', type=str, default='STDCNet813', help='STDC backbone')
|
||||||
|
parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final.pth',
|
||||||
|
help='The weights of STDC')
|
||||||
|
parser.add_argument('--stdc-scale', type=float, default=0.75, help='The scale of STDC')
|
||||||
|
parser.add_argument('--use-boundary-2', type=bool, default=False, help='')
|
||||||
|
parser.add_argument('--use-boundary-4', type=bool, default=False, help='')
|
||||||
|
parser.add_argument('--use-boundary-8', type=bool, default=False, help='')
|
||||||
|
parser.add_argument('--use-boundary-16', type=bool, default=False, help='')
|
||||||
|
parser.add_argument('--use-conv-last', type=bool, default=False, help='')
|
||||||
|
|
||||||
add_common_arguments(parser)
|
add_common_arguments(parser)
|
||||||
return parser
|
return parser
|
||||||
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue